xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/ir/ir_views.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/util/irange.h>
4 #include <torch/csrc/jit/ir/ir.h>
5 
6 namespace torch::jit {
7 
8 struct IfView {
IfViewIfView9   explicit IfView(Node* node) : node_(node) {
10     AT_ASSERT(node->kind() == ::c10::prim::If);
11   }
condIfView12   Value* cond() const {
13     return node_->input(0);
14   }
thenBlockIfView15   Block* thenBlock() const {
16     return node_->blocks().at(0);
17   }
elseBlockIfView18   Block* elseBlock() const {
19     return node_->blocks().at(1);
20   }
thenOutputsIfView21   ArrayRef<Value*> thenOutputs() const {
22     return thenBlock()->outputs();
23   }
elseOutputsIfView24   ArrayRef<Value*> elseOutputs() const {
25     return elseBlock()->outputs();
26   }
outputsIfView27   ArrayRef<Value*> outputs() const {
28     return node_->outputs();
29   }
nodeIfView30   Node* node() const {
31     return node_;
32   }
33   operator Node*() const {
34     return node_;
35   }
36 
permuteOutputsIfView37   void permuteOutputs(const std::vector<size_t>& new_output_order) {
38     node_->permuteOutputs(new_output_order);
39     thenBlock()->permuteOutputs(new_output_order);
40     elseBlock()->permuteOutputs(new_output_order);
41   }
42 
43  private:
44   Node* node_;
45 };
46 
47 struct LoopView {
LoopViewLoopView48   explicit LoopView(Node* node) : node_(node) {
49     AT_ASSERT(
50         node->kind() == ::c10::prim::Loop || node->kind() == ::c10::onnx::Loop);
51   }
bodyBlockLoopView52   Block* bodyBlock() const {
53     return node_->blocks().at(0);
54   }
condLoopView55   Value* cond() const {
56     return node_->input(0);
57   }
maxTripCountLoopView58   Value* maxTripCount() const {
59     return node_->input(0);
60   }
inputCondLoopView61   Value* inputCond() const {
62     return node_->input(1);
63   }
nextCondLoopView64   Value* nextCond() const {
65     return bodyBlock()->outputs().at(0);
66   }
currentTripCountLoopView67   Value* currentTripCount() const {
68     return bodyBlock()->inputs().at(0);
69   }
carriedInputsLoopView70   ArrayRef<Value*> carriedInputs() const {
71     // skip trip count and cond
72     return node_->inputs().slice(2);
73   }
carriedInputsWithCondLoopView74   ArrayRef<Value*> carriedInputsWithCond() const {
75     // skip trip count and cond
76     return node_->inputs().slice(1);
77   }
carriedOutputsLoopView78   ArrayRef<Value*> carriedOutputs() const {
79     return node_->outputs();
80   }
bodyCarriedInputsLoopView81   ArrayRef<Value*> bodyCarriedInputs() const {
82     // skip trip count and cond
83     return bodyBlock()->inputs().slice(1);
84   }
bodyCarriedOutputsLoopView85   ArrayRef<Value*> bodyCarriedOutputs() const {
86     return bodyBlock()->outputs().slice(1);
87   }
nodeLoopView88   Node* node() const {
89     return node_;
90   }
91   operator Node*() const {
92     return node_;
93   }
94 
permuteLoopCarriedLoopView95   void permuteLoopCarried(const std::vector<size_t>& new_output_order) {
96     node_->permuteOutputs(new_output_order);
97     // skip trip count and cond
98     node_->permuteInputs(adjustIndices(2, new_output_order));
99     auto adjusted_block_order = adjustIndices(1, new_output_order);
100     bodyBlock()->permuteOutputs(adjusted_block_order);
101     bodyBlock()->permuteInputs(adjusted_block_order);
102   }
103 
replaceMaxTripCountLoopView104   void replaceMaxTripCount(Value* new_max_trip_count) {
105     node_->replaceInput(0, new_max_trip_count);
106   }
replaceInputConditionLoopView107   void replaceInputCondition(Value* new_input_condition) {
108     node_->replaceInput(1, new_input_condition);
109   }
110 
111   // our way of encoding loops makes them difficult to turn back into python
112   // syntax. we have to check properties of the condition and trip count inputs
113   // to figure out which one it initially was. ModifiedLoops are not directly
114   // mappable to either For or While
115   enum LoopType { While, For, ModifiedLoop };
116 
loopTypeLoopView117   LoopType loopType() {
118     auto trip_count = toIValue(maxTripCount());
119     auto cond_input = toIValue(inputCond());
120     auto cond_next = toIValue(nextCond());
121 
122     bool condition_is_always_true =
123         cond_input && cond_input->toBool() && cond_next && cond_next->toBool();
124     bool trip_count_is_specified = !trip_count || // trip is not a constant
125         trip_count->toInt() !=
126             std::numeric_limits<int64_t>::max() || // it is a constant but not
127                                                    // the default one
128         !currentTripCount()
129              ->uses()
130              .empty(); // it is actually being used in the body.
131 
132     if (condition_is_always_true) {
133       // if the trip count was not specified this was a user-written while True:
134       return trip_count_is_specified ? For : While;
135     } else {
136       if (trip_count_is_specified) {
137         return ModifiedLoop;
138       }
139       return While;
140     }
141   }
142 
143  private:
144   Node* node_;
145 
146   // adjust index_ordering by adding indices 0 - thorugh adjust, and
147   // incrementing all existing inputs by adjust
adjustIndicesLoopView148   static std::vector<size_t> adjustIndices(
149       size_t adjust,
150       const std::vector<size_t>& index_ordering) {
151     std::vector<size_t> adjusted;
152     adjusted.reserve(adjust + index_ordering.size());
153     for (const auto i : c10::irange(adjust)) {
154       adjusted.push_back(i);
155     }
156     for (auto index : index_ordering) {
157       adjusted.push_back(index + adjust);
158     }
159     return adjusted;
160   }
161 };
162 } // namespace torch::jit
163