#pragma once

#include <algorithm>
#include <list>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <torch/csrc/jit/tensorexpr/expr.h>

namespace torch::jit::tensorexpr {

// The common base class for all statement nodes.
class TORCH_API Stmt : public std::enable_shared_from_this<Stmt> {
 public:
  Stmt() = default;
  virtual ~Stmt() = default;
  virtual void accept(IRVisitor* visitor) = 0;
  virtual StmtPtr accept_mutator(IRMutator* mutator) = 0;

  StmtPtr get_parent() const {
    return parent_ ? parent_->getptr() : nullptr;
  }

  /*
   * Make a deep copy of the given statement.
   *
   * All statements and expressions used in children of the statement are
   * cloned. Note that the variables are not deep-copied since they are
   * immutable.
   */
  static StmtPtr clone(const StmtPtr& s);

 protected:
  static void set_parent(const StmtPtr& s, Stmt* new_parent) {
    s->parent_ = new_parent;
  }
  std::shared_ptr<Stmt> getptr() {
    return shared_from_this();
  }

 private:
  Stmt* parent_ = nullptr;
};

template <class Op>
class StmtNode : public Stmt {
 public:
  using StmtNodeBase = StmtNode<Op>;
  void accept(IRVisitor* visitor) override {
    visitor->visit(static_to<Op>(getptr()));
  }
  StmtPtr accept_mutator(IRMutator* mutator) override;
  StmtNode() = default;
};

template <class Op>
StmtPtr StmtNode<Op>::accept_mutator(IRMutator* mutator) {
  return mutator->mutate(static_to<Op>(getptr()));
}

// Concrete Stmt classes
class TORCH_API Block : public StmtNode<Block> {
 public:
  static BlockPtr make(const std::vector<StmtPtr>& stmts) {
    std::vector<StmtPtr> valid_stmts;
    for (auto& stmt : stmts) {
      if (!stmt) {
        continue;
      }
      valid_stmts.push_back(stmt);
    }
    if (valid_stmts.empty()) {
      return nullptr;
    }
    return alloc<Block>(valid_stmts);
  }

  size_t nstmts() const {
    return stmts_.size();
  }
  bool empty() const {
    return stmts_.empty();
  }

  void prepend_stmt(StmtPtr s) {
    if (s->get_parent()) {
      throw malformed_input(
          "Block prepend Stmt with existing parent", std::move(s));
    }

    stmts_.push_front(s);
    set_parent(s, this);
  }
  void append_stmt(StmtPtr s) {
    if (s->get_parent()) {
      throw malformed_input(
          "Block append Stmt with existing parent", std::move(s));
    }

    stmts_.push_back(s);
    set_parent(s, this);
  }

  void insert_stmt_before(StmtPtr s, const StmtPtr& before) {
    if (s->get_parent()) {
      throw malformed_input(
          "Block insert Stmt with existing parent", std::move(s));
    }

    auto pos = std::find(stmts_.begin(), stmts_.end(), before);
    if (pos == stmts_.end()) {
      throw malformed_input(
          "Inserting before statement that is not in block", std::move(s));
    }

    stmts_.insert(pos, s);
    set_parent(s, this);
  }

  void insert_stmt_after(StmtPtr s, const StmtPtr& after) {
    if (s->get_parent()) {
      throw malformed_input(
          "Block insert Stmt with existing parent", std::move(s));
    }

    auto pos = std::find(stmts_.begin(), stmts_.end(), after);
    if (pos == stmts_.end()) {
      throw malformed_input(
          "Inserting after statement that is not in block", std::move(s));
    }

    ++pos;

    stmts_.insert(pos, s);
    set_parent(s, this);
  }

  bool replace_stmt(const StmtPtr& old_stmt, StmtPtr new_stmt) {
    if (new_stmt->get_parent()) {
      throw malformed_input(
          "Block replace Stmt with existing parent", std::move(new_stmt));
    }

    auto pos = std::find(stmts_.begin(), stmts_.end(), old_stmt);
    if (pos == stmts_.end()) {
      return false;
    }
    stmts_.insert(pos, new_stmt);
    stmts_.erase(pos);
    set_parent(old_stmt, nullptr);
    set_parent(new_stmt, this);
    return true;
  }

  // Creates a new block by cloning `this` block and replacing the given
  // statement with a new statement. Note that `old_stmt` refers to a statement
  // in `this` block. If the `old_stmt` is not found, it will return `nullptr`.
  BlockPtr clone_and_replace(const StmtPtr& old_stmt, StmtPtr new_stmt) {
    if (new_stmt->get_parent()) {
      throw malformed_input(
          "Block replace Stmt with existing parent", std::move(new_stmt));
    }

    std::vector<StmtPtr> stmts(stmts_.begin(), stmts_.end());
    std::vector<StmtPtr> cloned_stmts(stmts.size());
    bool found = false;
    for (int i = 0; i < static_cast<int>(stmts.size()); ++i) {
      if (stmts[i] == old_stmt) {
        found = true;
        cloned_stmts[i] = new_stmt;
      } else {
        cloned_stmts[i] = Stmt::clone(stmts[i]);
      }
    }
    if (!found) {
      return nullptr;
    }
    return alloc<Block>(cloned_stmts);
  }

  bool remove_stmt(const StmtPtr& stmt) {
    auto pos = std::find(stmts_.begin(), stmts_.end(), stmt);
    if (pos == stmts_.end()) {
      return false;
    }

    set_parent(stmt, nullptr);
    stmts_.erase(pos);
    return true;
  }

  std::list<StmtPtr> stmts() const {
    return stmts_;
  }

  void clear() {
    for (const auto& s : stmts_) {
      set_parent(s, nullptr);
    }
    stmts_.clear();
  }

  void set_stmts(const std::vector<StmtPtr>& stmts) {
    clear();
    init(stmts);
  }

  explicit Block(const std::vector<StmtPtr>& stmts) {
    init(stmts);
  }

  typedef std::list<StmtPtr>::iterator iterator;
  typedef std::list<StmtPtr>::const_iterator const_iterator;

  iterator begin() {
    return stmts_.begin();
  }

  const_iterator begin() const {
    return stmts_.begin();
  }

  iterator end() {
    return stmts_.end();
  }

  const_iterator end() const {
    return stmts_.end();
  }

  StmtPtr front() {
    return stmts_.front();
  }

  StmtPtr front() const {
    return stmts_.front();
  }

  StmtPtr back() {
    return stmts_.back();
  }

  StmtPtr back() const {
    return stmts_.back();
  }

  void splice(Block::iterator it, const BlockPtr& other) {
    for (const StmtPtr& s : *other) {
      set_parent(s, this);
    }

    stmts_.splice(it, other->stmts_);
  }

  static BlockPtr getSharedParent(StmtPtr p1, StmtPtr p2) {
    std::unordered_set<BlockPtr> enclosing;

    StmtPtr p1_p = std::move(p1);
    while (p1_p) {
      if (BlockPtr b = to<Block>(p1_p)) {
        enclosing.insert(b);
      }
      p1_p = p1_p->get_parent();
    }

    StmtPtr p2_p = std::move(p2);
    while (p2_p) {
      if (BlockPtr b = to<Block>(p2_p)) {
        if (enclosing.count(b) != 0) {
          return b;
        }
      }
      p2_p = p2_p->get_parent();
    }

    return nullptr;
  }

  // Returns the immediate child containing statement s.
  StmtPtr getEnclosedRoot(StmtPtr s) const {
    while (s && s->get_parent().get() != this) {
      s = s->get_parent();
    }
    return s;
  }

 private:
  std::list<StmtPtr> stmts_;

  void init(const std::vector<StmtPtr>& stmts) {
    for (const StmtPtr& s : stmts) {
      if (!s) {
        continue;
      }
      // A statement that already has a parent here indicates a bug; we cannot
      // throw an error from a constructor, but the IR verifier will catch it.
      if (!s->get_parent()) {
        set_parent(s, this);
      }

      stmts_.push_back(s);
    }
  }
};
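
// A minimal usage sketch (not part of the API). It assumes BufHandle/VarHandle
// and expression construction from expr.h, and the Store statement declared
// below:
//
//   BufHandle a("A", {64}, kFloat);
//   VarHandle i("i", kInt);
//   StmtPtr s1 = Store::make(a, {i}, ExprHandle(0.f));
//   StmtPtr s2 = Store::make(a, {i + 1}, ExprHandle(1.f));
//   BlockPtr block = Block::make({s1, s2}); // nulls are skipped; empty -> nullptr
//   // Statements added to a block must not already have a parent:
//   // block->append_stmt(s1); // would throw malformed_input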

class TORCH_API Store : public StmtNode<Store> {
 public:
  VarPtr base_handle() const {
    return buf_->base_handle();
  }
  std::vector<ExprPtr> indices() const {
    return indices_;
  }
  ExprPtr flat_index() const {
    TORCH_CHECK(indices_.size() == 1, "Indices haven't been flattened.");
    return indices_[0];
  }
  ExprPtr value() const {
    return value_;
  }
  BufPtr buf() const {
    return buf_;
  }

  void set_buf(BufPtr buf) {
    buf_ = std::move(buf);
  }

  void set_indices(std::vector<ExprPtr> indices) {
    indices_ = std::move(indices);
  }

  void set_value(ExprPtr value) {
    value_ = std::move(value);
  }

  static StorePtr make(
      const BufHandle& buf,
      const std::vector<ExprHandle>& indices,
      const ExprHandle& value);

  Store(BufPtr buf, std::vector<ExprPtr> indices, ExprPtr value);

 private:
  BufPtr buf_;
  std::vector<ExprPtr> indices_;
  ExprPtr value_;
};
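
// A sketch of typical Store usage: writing a computed value into a 2-D buffer.
// `Load` is declared in ir.h, not in this header; it is shown only for
// illustration.
//
//   BufHandle b("B", {8, 8}, kFloat);
//   VarHandle x("x", kInt), y("y", kInt);
//   StorePtr st = Store::make(b, {x, y}, Load::make(b, {x, y}) + 1.f);
//   // st->flat_index() would fail the TORCH_CHECK here, since the two indices
//   // have not been flattened into a single linear index yet.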

// Allocate a buffer of the given shape and dtype and bind it to the given
// buffer var. The lifetime lasts at most through the current program, until
// the buffer is explicitly freed. Leaving memory unfreed is likely an error.
class TORCH_API Allocate : public StmtNode<Allocate> {
 public:
  static AllocatePtr make(const BufHandle& buf_handle) {
    return alloc<Allocate>(buf_handle.node());
  }

  VarPtr buffer_var() const {
    return buf_->base_handle();
  }

  Dtype dtype() const {
    return buf_->dtype();
  }

  const std::vector<ExprPtr> dims() const {
    return buf_->dims();
  }

  BufPtr buf() const {
    return buf_;
  }

  void set_buf(BufPtr buf) {
    buf_ = std::move(buf);
  }

  explicit Allocate(BufPtr buf) : buf_(std::move(buf)) {}

 private:
  BufPtr buf_;
  // TODO: add memory types.
};
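
// A sketch: allocating a scratch buffer for the duration of a kernel. The
// matching Free statement is shown with the Free node below. Assumes BufHandle
// construction from expr.h.
//
//   BufHandle tmp("tmp", {256}, kFloat);
//   AllocatePtr alloc_tmp = Allocate::make(tmp);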

// PlacementAllocate is a variation of the Allocate operator in NNC IR. It does
// not allocate memory; instead, it reuses the memory of another buffer for the
// given buffer.
class TORCH_API PlacementAllocate : public StmtNode<PlacementAllocate> {
 public:
  static PlacementAllocatePtr make(
      const BufHandle& buf_handle,
      const BufHandle& buf_handle_to_reuse) {
    return alloc<PlacementAllocate>(
        buf_handle.node(), buf_handle_to_reuse.node());
  }

  BufPtr buf() const {
    return buf_;
  }

  BufPtr buf_to_reuse() const {
    return buf_to_reuse_;
  }

  void set_buf(BufPtr buf) {
    buf_ = std::move(buf);
  }

  void set_buf_to_reuse(BufPtr buf) {
    buf_to_reuse_ = std::move(buf);
  }

  explicit PlacementAllocate(BufPtr buf, BufPtr buf_to_reuse)
      : buf_(std::move(buf)), buf_to_reuse_(std::move(buf_to_reuse)) {}

 private:
  BufPtr buf_;
  BufPtr buf_to_reuse_;
};
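
// A sketch: directing buffer `C` to reuse the storage of buffer `A` instead of
// receiving its own allocation.
//
//   BufHandle a("A", {1024}, kFloat);
//   BufHandle c("C", {1024}, kFloat);
//   PlacementAllocatePtr pa = PlacementAllocate::make(c, a); // C reuses A's memory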

// Free the specified buffer. Using the buffer after it has been freed is an
// error.
class TORCH_API Free : public StmtNode<Free> {
 public:
  static FreePtr make(const BufHandle& buf_handle) {
    return alloc<Free>(buf_handle.node());
  }

  VarPtr buffer_var() const {
    return buf_->base_handle();
  }

  BufPtr buf() const {
    return buf_;
  }

  void set_buf(BufPtr buf) {
    buf_ = std::move(buf);
  }

  explicit Free(BufPtr buf) : buf_(std::move(buf)) {}

 private:
  BufPtr buf_;
};
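
// A sketch continuing the Allocate example above: an Allocate/Free pair
// typically brackets the statements that use the buffer.
//
//   FreePtr free_tmp = Free::make(tmp);
//   BlockPtr scope =
//       Block::make({alloc_tmp, /* ...statements using tmp... */ free_tmp});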

class TORCH_API FreeExt : public StmtNode<FreeExt> {
 public:
  static FreeExtPtr make(const std::vector<BufHandle>& bufs);

  std::vector<BufPtr> bufs() const {
    return bufs_;
  }

  void set_bufs(std::vector<BufPtr> bufs) {
    bufs_ = std::move(bufs);
  }

  explicit FreeExt(std::vector<BufPtr> bufs) : bufs_(std::move(bufs)) {}

 private:
  std::vector<BufPtr> bufs_;
};

class TORCH_API Let : public StmtNode<Let> {
 public:
  static LetPtr make(const VarHandle& var, const ExprHandle& val) {
    return alloc<Let>(var.node(), val.node());
  }

  Let(VarPtr var, ExprPtr val) : var_(std::move(var)), val_(std::move(val)) {}

  VarPtr var() const {
    return var_;
  }

  ExprPtr value() const {
    return val_;
  }

  void set_var(VarPtr var) {
    var_ = std::move(var);
  }

  void set_val(ExprPtr val) {
    val_ = std::move(val);
  }

 private:
  VarPtr var_;
  ExprPtr val_;
};
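
// A sketch: binding a scalar value to a variable within a block. Assumes
// VarHandle/ExprHandle construction from expr.h.
//
//   VarHandle v("v", kFloat);
//   LetPtr binding = Let::make(v, ExprHandle(3.f));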

class TORCH_API Cond : public StmtNode<Cond> {
 public:
  static CondPtr make(
      const ExprHandle& condition,
      const StmtPtr& true_stmt,
      const StmtPtr& false_stmt) {
    return alloc<Cond>(condition.node(), true_stmt, false_stmt);
  }

  ExprPtr condition() const {
    return condition_;
  }

  BlockPtr true_stmt() const {
    return true_stmt_;
  }

  BlockPtr false_stmt() const {
    return false_stmt_;
  }

  void set_condition(ExprPtr condition) {
    condition_ = std::move(condition);
  }

  void set_true_stmt(StmtPtr true_stmt) {
    if (true_stmt) {
      BlockPtr b = to<Block>(true_stmt);
      if (!b) {
        b = alloc<Block>(std::vector<StmtPtr>({std::move(true_stmt)}));
      }
      true_stmt_ = b;
      set_parent(true_stmt_, this);
    }
  }

  void set_false_stmt(StmtPtr false_stmt) {
    if (false_stmt) {
      BlockPtr b = to<Block>(false_stmt);
      if (!b) {
        b = alloc<Block>(std::vector<StmtPtr>({std::move(false_stmt)}));
      }
      false_stmt_ = b;
      set_parent(false_stmt_, this);
    }
  }

  Cond(ExprPtr condition, StmtPtr true_stmt, StmtPtr false_stmt)
      : condition_(std::move(condition)) {
    set_true_stmt(std::move(true_stmt));
    set_false_stmt(std::move(false_stmt));
  }

  CondPtr cloneWithNewBodies(
      const StmtPtr& true_stmt,
      const StmtPtr& false_stmt) {
    return alloc<Cond>(condition_, true_stmt, false_stmt);
  }

  CondPtr cloneWithNewBody(const StmtPtr& true_stmt) {
    return alloc<Cond>(condition_, true_stmt, nullptr);
  }

 private:
  ExprPtr condition_;
  BlockPtr true_stmt_ = nullptr;
  BlockPtr false_stmt_ = nullptr;
};
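
// A sketch: a guarded store. The condition is any scalar ExprHandle; the
// comparison helper CompareSelect used here is declared in ir.h, not in this
// header.
//
//   BufHandle out("out", {16}, kFloat);
//   VarHandle x("x", kInt);
//   ExprHandle cond = CompareSelect::make(x, 0, kGT);
//   CondPtr if_stmt =
//       Cond::make(cond, Store::make(out, {x}, ExprHandle(1.f)), nullptr);
//   // Non-Block branches are wrapped into Blocks automatically.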

class TORCH_API LoopOptions {
 public:
  enum {
    IDX_UNSET = -1,
    IDX_X = 0,
    IDX_Y = 1,
    IDX_Z = 2,
    IDX_W = 3,
    IDX_MAX = IDX_W,
  };
  // GPU Block Index
  bool is_gpu_block_index() const {
    return gpu_block_index_ != IDX_UNSET;
  }

  int gpu_block_index() const {
    return gpu_block_index_;
  }

  std::string gpu_block_index_str() const {
    if (!is_gpu_block_index()) {
      throw malformed_input("Has no GPU block index");
    }

    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    static const char* kBlockIndexNames[] = {
        "blockIdx.x",
        "blockIdx.y",
        "blockIdx.z",
        "blockIdx.w",
    };

    if (gpu_block_index_ < IDX_X || gpu_block_index_ > IDX_MAX) {
      throw malformed_input("invalid GPU block index");
    }

    return kBlockIndexNames[gpu_block_index_];
  }

  void set_gpu_block_index(int index) {
    if (index == IDX_UNSET) {
      gpu_block_index_ = IDX_UNSET;
    }

    if (is_gpu_thread_index()) {
      throw std::runtime_error("Cannot set both gpu block and thread index");
    }
    if (is_gpu_block_index() && gpu_block_index() != index) {
      throw std::runtime_error("Cannot set a previously set block index");
    }
    gpu_block_index_ = index;
  }

  // GPU Thread Index
  bool is_gpu_thread_index() const {
    return gpu_thread_index() != IDX_UNSET;
  }

  int gpu_thread_index() const {
    return gpu_thread_index_;
  }

  std::string gpu_thread_index_str() const {
    if (!is_gpu_thread_index()) {
      throw malformed_input("has no GPU thread index");
    }

    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    static const char* kThreadIndexNames[] = {
        "threadIdx.x", "threadIdx.y", "threadIdx.z", "threadIdx.w"};

    if (gpu_thread_index_ < IDX_X || gpu_thread_index_ > IDX_MAX) {
      throw malformed_input("invalid GPU thread index");
    }

    return kThreadIndexNames[gpu_thread_index_];
  }

  void set_gpu_thread_index(int index) {
    if (index == IDX_UNSET) {
      gpu_thread_index_ = IDX_UNSET;
    }

    if (is_gpu_block_index()) {
      throw std::runtime_error("Cannot set both gpu thread and block index");
    }
    if (is_gpu_thread_index() && gpu_thread_index() != index) {
      throw std::runtime_error("Cannot set a previously set thread index");
    }
    gpu_thread_index_ = index;
  }

  void set_parallel() {
    is_parallel_ = true;
  }

  bool is_parallel() const {
    return is_parallel_;
  }

  std::string ToString() const {
    if (is_gpu_block_index()) {
      return gpu_block_index_str();
    } else if (is_gpu_thread_index()) {
      return gpu_thread_index_str();
    } else if (is_parallel()) {
      return "parallel";
    }
    return "";
  }

  bool isDefault() const {
    return gpu_block_index_ == IDX_UNSET && gpu_thread_index_ == IDX_UNSET &&
        !is_parallel_;
  }

  void set_buffer_mapping(const std::unordered_map<std::string, BufPtr>& map) {
    map_input_to_tensor_bufs_ = map;
  }

  std::unordered_map<std::string, BufPtr> get_buffer_mapping() const {
    return map_input_to_tensor_bufs_;
  }

 private:
  int gpu_block_index_{IDX_UNSET};
  int gpu_thread_index_{IDX_UNSET};
  bool is_parallel_{false};
  std::unordered_map<std::string, BufPtr> map_input_to_tensor_bufs_;
};
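
// A sketch: marking an axis for GPU mapping or CPU parallelism. A loop axis can
// be bound to a block index or a thread index, but not both.
//
//   LoopOptions opts;
//   opts.set_gpu_block_index(LoopOptions::IDX_X);
//   opts.gpu_block_index_str(); // "blockIdx.x"
//
//   LoopOptions par;
//   par.set_parallel();
//   par.ToString(); // "parallel"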

class TORCH_API For : public StmtNode<For> {
 public:
  VarPtr var() const {
    return var_;
  }
  ExprPtr start() const {
    return start_;
  }
  ExprPtr stop() const {
    return stop_;
  }
  BlockPtr body() const {
    return body_;
  }
  static ForPtr make(
      const VarHandle& var,
      const ExprHandle& start,
      const ExprHandle& stop,
      const StmtPtr& body) {
    if (!body) {
      return nullptr;
    }
    return alloc<For>(var.node(), start.node(), stop.node(), body);
  }
  static ForPtr make(
      const VarHandle& var,
      const ExprHandle& start,
      const ExprHandle& stop,
      const StmtPtr& body,
      const LoopOptions& loop_options) {
    if (!body) {
      return nullptr;
    }
    return alloc<For>(
        var.node(), start.node(), stop.node(), body, loop_options);
  }
  const LoopOptions loop_options() const {
    return loop_options_;
  }

  For(VarPtr var, ExprPtr start, ExprPtr stop, StmtPtr body)
      : var_(std::move(var)), start_(std::move(start)), stop_(std::move(stop)) {
    BlockPtr b = to<Block>(body);
    if (!b) {
      b = alloc<Block>(std::vector<StmtPtr>({std::move(body)}));
    }
    body_ = b;
    set_parent(body_, this);
  }

  For(VarPtr var,
      ExprPtr start,
      ExprPtr stop,
      StmtPtr body,
      LoopOptions loop_options)
      : var_(std::move(var)),
        start_(std::move(start)),
        stop_(std::move(stop)),
        loop_options_(std::move(loop_options)) {
    if (!var_) {
      throw malformed_input("invalid Var in For loop");
    } else if (!start_) {
      throw malformed_input("invalid Start in For loop");
    } else if (!stop_) {
      throw malformed_input("invalid Stop in For loop");
    } else if (!body || body->get_parent()) {
      throw malformed_input("invalid Body in For loop");
    }

    BlockPtr b = to<Block>(body);
    if (!b) {
      b = alloc<Block>(std::vector<StmtPtr>({std::move(body)}));
    }
    body_ = b;
    set_parent(body_, this);
  }

  void set_gpu_block_index(int block_index) {
    loop_options_.set_gpu_block_index(block_index);
  }

  void set_gpu_thread_index(int thread_index) {
    loop_options_.set_gpu_thread_index(thread_index);
  }

  void set_parallel() {
    loop_options_.set_parallel();
  }

  bool is_parallel() const {
    return loop_options_.is_parallel();
  }

  void set_buffer_map(const std::unordered_map<std::string, BufPtr>& map) {
    loop_options_.set_buffer_mapping(map);
  }

  ForPtr cloneWithNewBody(const StmtPtr& body) const {
    return alloc<For>(var_, start_, stop_, body, loop_options_);
  }

  BlockPtr removeBody() {
    auto res = body_;
    set_parent(res, nullptr);
    body_ = nullptr;
    return res;
  }

  void set_body(StmtPtr body) {
    BlockPtr b = to<Block>(body);
    if (!b) {
      b = alloc<Block>(std::vector<StmtPtr>({std::move(body)}));
    }
    body_ = b;
    set_parent(body_, this);
  }

  void set_start(ExprPtr start) {
    start_ = std::move(start);
  }

  void set_stop(ExprPtr stop) {
    stop_ = std::move(stop);
  }

  void set_var(VarPtr var) {
    var_ = std::move(var);
  }

 private:
  VarPtr var_;
  ExprPtr start_;
  ExprPtr stop_;
  BlockPtr body_;
  LoopOptions loop_options_;
};
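
// A sketch: a loop writing 0..127 into a buffer, then mapped to GPU threads via
// LoopOptions. Assumes BufHandle/VarHandle construction from expr.h.
//
//   VarHandle i("i", kInt);
//   BufHandle out("out", {128}, kInt);
//   ForPtr loop = For::make(i, 0, 128, Store::make(out, {i}, i));
//   loop->set_gpu_thread_index(LoopOptions::IDX_X); // bind to threadIdx.x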

// A backend-specific IR node that implements atomic add.
// This node can only show up internally with GPU backends.
// TODO: move this to an internal IR.
// TODO: make IR nodes extensible.
class TORCH_API AtomicAdd : public StmtNode<AtomicAdd> {
 public:
  AtomicAdd(BufPtr buf, std::vector<ExprPtr> indices, ExprPtr value)
      : buf_(std::move(buf)),
        indices_(std::move(indices)),
        value_(std::move(value)) {}

  VarPtr base_handle() const {
    return buf_->base_handle();
  }

  BufPtr buf() const {
    return buf_;
  }

  ExprPtr flat_index() const {
    TORCH_CHECK(indices_.size() == 1, "Indices haven't been flattened.");
    return indices_[0];
  }

  ExprPtr value() const {
    return value_;
  }

  const std::vector<ExprPtr>& indices() const {
    return indices_;
  }

  void set_buf(BufPtr buf) {
    buf_ = std::move(buf);
  }

  void set_indices(std::vector<ExprPtr> indices) {
    indices_ = std::move(indices);
  }

  void set_value(ExprPtr value) {
    value_ = std::move(value);
  }

 private:
  BufPtr buf_;
  std::vector<ExprPtr> indices_;
  ExprPtr value_;
};

class TORCH_API SyncThreads : public StmtNode<SyncThreads> {
 public:
  SyncThreads() = default;
};

/*
 * ExternalCall statement represents a call to an external function that would
 * compute the contents of the output buffer. An ExternalCall statement consists
 * of:
 * 1) output buffer - the buffer that'll be initialized by the call
 * 2) external function name - a key from the NNC function registry to look up
 *    the actual function to call
 * 3) buffer arguments - the input buffers used by the function
 * 4) non-buffer arguments - scalar arguments to pass to the function
 *
 * An example:
 *   A = nnc_conv2d(buf_args={Input, Weight, Bias}, args={1})
 * Here 'A' is the output buffer, "nnc_conv2d" is the function name, the buffer
 * arguments are 'Input', 'Weight', and 'Bias', and there is a single non-buffer
 * argument - 1.
 *
 * The semantics of the scalar arguments is defined solely by the implementation
 * of the external function.
 */
class TORCH_API ExternalCall : public StmtNode<ExternalCall> {
 public:
  static ExternalCallPtr make(
      BufHandle buf,
      const std::string& func_name,
      const std::vector<BufHandle>& buf_args,
      const std::vector<ExprHandle>& args);

  BufPtr buf() const {
    return buf_;
  }

  std::string func_name() const {
    return func_name_;
  }

  std::vector<BufPtr> buf_args() const {
    return buf_args_;
  }

  std::vector<ExprPtr> args() const {
    return args_;
  }

  void set_buf(BufPtr buf) {
    buf_ = std::move(buf);
  }

  void set_buf_args(std::vector<BufPtr> buf_args) {
    buf_args_ = std::move(buf_args);
  }

  void set_args(std::vector<ExprPtr> args) {
    args_ = std::move(args);
  }

  ExternalCall(
      BufPtr buf,
      std::string func_name,
      std::vector<BufPtr> buf_args,
      std::vector<ExprPtr> args)
      : buf_(std::move(buf)),
        func_name_(std::move(func_name)),
        buf_args_(std::move(buf_args)),
        args_(std::move(args)) {}

 private:
  BufPtr buf_;
  std::string func_name_;
  std::vector<BufPtr> buf_args_;
  std::vector<ExprPtr> args_;
};
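
// A sketch mirroring the nnc_conv2d illustration above. The function name must
// match a key registered in the NNC external-function registry; `input`,
// `weight`, and `bias` are assumed to be BufHandles defined elsewhere.
//
//   BufHandle output("A", {1, 16, 32, 32}, kFloat);
//   ExternalCallPtr call = ExternalCall::make(
//       output, "nnc_conv2d", {input, weight, bias}, {ExprHandle(1)});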

class TORCH_API ExternalCallWithAlloc : public StmtNode<ExternalCallWithAlloc> {
 public:
  static ExternalCallWithAllocPtr make(
      const std::string& func_name,
      const std::vector<BufHandle>& buf_out_args,
      const std::vector<BufHandle>& buf_args,
      const std::vector<ExprHandle>& args);

  std::vector<BufPtr> buf_out_args() const {
    return buf_out_args_;
  }

  std::string func_name() const {
    return func_name_;
  }

  std::vector<BufPtr> buf_args() const {
    return buf_args_;
  }

  std::vector<ExprPtr> args() const {
    return args_;
  }

  void set_buf_out_args(std::vector<BufPtr> buf_out_args) {
    buf_out_args_ = std::move(buf_out_args);
  }

  void set_buf_args(std::vector<BufPtr> buf_args) {
    buf_args_ = std::move(buf_args);
  }

  void set_args(std::vector<ExprPtr> args) {
    args_ = std::move(args);
  }

  ExternalCallWithAlloc(
      std::string func_name,
      std::vector<BufPtr> buf_out_args,
      std::vector<BufPtr> buf_args,
      std::vector<ExprPtr> args)
      : func_name_(std::move(func_name)),
        buf_out_args_(std::move(buf_out_args)),
        buf_args_(std::move(buf_args)),
        args_(std::move(args)) {}

 private:
  std::string func_name_;
  std::vector<BufPtr> buf_out_args_;
  std::vector<BufPtr> buf_args_;
  std::vector<ExprPtr> args_;
};
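
// A sketch: like ExternalCall, but the external function allocates its output
// buffers itself, so only output buffer handles are passed. "my_fn" is a
// hypothetical registry key; `out1`, `out2`, and `input` are assumed
// BufHandles defined elsewhere.
//
//   ExternalCallWithAllocPtr call = ExternalCallWithAlloc::make(
//       "my_fn", {out1, out2}, {input}, {});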

} // namespace torch::jit::tensorexpr