1 #pragma once 2 3 #include <c10/util/irange.h> 4 #include <torch/csrc/jit/ir/ir.h> 5 6 namespace torch::jit { 7 8 struct IfView { IfViewIfView9 explicit IfView(Node* node) : node_(node) { 10 AT_ASSERT(node->kind() == ::c10::prim::If); 11 } condIfView12 Value* cond() const { 13 return node_->input(0); 14 } thenBlockIfView15 Block* thenBlock() const { 16 return node_->blocks().at(0); 17 } elseBlockIfView18 Block* elseBlock() const { 19 return node_->blocks().at(1); 20 } thenOutputsIfView21 ArrayRef<Value*> thenOutputs() const { 22 return thenBlock()->outputs(); 23 } elseOutputsIfView24 ArrayRef<Value*> elseOutputs() const { 25 return elseBlock()->outputs(); 26 } outputsIfView27 ArrayRef<Value*> outputs() const { 28 return node_->outputs(); 29 } nodeIfView30 Node* node() const { 31 return node_; 32 } 33 operator Node*() const { 34 return node_; 35 } 36 permuteOutputsIfView37 void permuteOutputs(const std::vector<size_t>& new_output_order) { 38 node_->permuteOutputs(new_output_order); 39 thenBlock()->permuteOutputs(new_output_order); 40 elseBlock()->permuteOutputs(new_output_order); 41 } 42 43 private: 44 Node* node_; 45 }; 46 47 struct LoopView { LoopViewLoopView48 explicit LoopView(Node* node) : node_(node) { 49 AT_ASSERT( 50 node->kind() == ::c10::prim::Loop || node->kind() == ::c10::onnx::Loop); 51 } bodyBlockLoopView52 Block* bodyBlock() const { 53 return node_->blocks().at(0); 54 } condLoopView55 Value* cond() const { 56 return node_->input(0); 57 } maxTripCountLoopView58 Value* maxTripCount() const { 59 return node_->input(0); 60 } inputCondLoopView61 Value* inputCond() const { 62 return node_->input(1); 63 } nextCondLoopView64 Value* nextCond() const { 65 return bodyBlock()->outputs().at(0); 66 } currentTripCountLoopView67 Value* currentTripCount() const { 68 return bodyBlock()->inputs().at(0); 69 } carriedInputsLoopView70 ArrayRef<Value*> carriedInputs() const { 71 // skip trip count and cond 72 return node_->inputs().slice(2); 73 } carriedInputsWithCondLoopView74 ArrayRef<Value*> carriedInputsWithCond() const { 75 // skip trip count and cond 76 return node_->inputs().slice(1); 77 } carriedOutputsLoopView78 ArrayRef<Value*> carriedOutputs() const { 79 return node_->outputs(); 80 } bodyCarriedInputsLoopView81 ArrayRef<Value*> bodyCarriedInputs() const { 82 // skip trip count and cond 83 return bodyBlock()->inputs().slice(1); 84 } bodyCarriedOutputsLoopView85 ArrayRef<Value*> bodyCarriedOutputs() const { 86 return bodyBlock()->outputs().slice(1); 87 } nodeLoopView88 Node* node() const { 89 return node_; 90 } 91 operator Node*() const { 92 return node_; 93 } 94 permuteLoopCarriedLoopView95 void permuteLoopCarried(const std::vector<size_t>& new_output_order) { 96 node_->permuteOutputs(new_output_order); 97 // skip trip count and cond 98 node_->permuteInputs(adjustIndices(2, new_output_order)); 99 auto adjusted_block_order = adjustIndices(1, new_output_order); 100 bodyBlock()->permuteOutputs(adjusted_block_order); 101 bodyBlock()->permuteInputs(adjusted_block_order); 102 } 103 replaceMaxTripCountLoopView104 void replaceMaxTripCount(Value* new_max_trip_count) { 105 node_->replaceInput(0, new_max_trip_count); 106 } replaceInputConditionLoopView107 void replaceInputCondition(Value* new_input_condition) { 108 node_->replaceInput(1, new_input_condition); 109 } 110 111 // our way of encoding loops makes them difficult to turn back into python 112 // syntax. we have to check properties of the condition and trip count inputs 113 // to figure out which one it initially was. ModifiedLoops are not directly 114 // mappable to either For or While 115 enum LoopType { While, For, ModifiedLoop }; 116 loopTypeLoopView117 LoopType loopType() { 118 auto trip_count = toIValue(maxTripCount()); 119 auto cond_input = toIValue(inputCond()); 120 auto cond_next = toIValue(nextCond()); 121 122 bool condition_is_always_true = 123 cond_input && cond_input->toBool() && cond_next && cond_next->toBool(); 124 bool trip_count_is_specified = !trip_count || // trip is not a constant 125 trip_count->toInt() != 126 std::numeric_limits<int64_t>::max() || // it is a constant but not 127 // the default one 128 !currentTripCount() 129 ->uses() 130 .empty(); // it is actually being used in the body. 131 132 if (condition_is_always_true) { 133 // if the trip count was not specified this was a user-written while True: 134 return trip_count_is_specified ? For : While; 135 } else { 136 if (trip_count_is_specified) { 137 return ModifiedLoop; 138 } 139 return While; 140 } 141 } 142 143 private: 144 Node* node_; 145 146 // adjust index_ordering by adding indices 0 - thorugh adjust, and 147 // incrementing all existing inputs by adjust adjustIndicesLoopView148 static std::vector<size_t> adjustIndices( 149 size_t adjust, 150 const std::vector<size_t>& index_ordering) { 151 std::vector<size_t> adjusted; 152 adjusted.reserve(adjust + index_ordering.size()); 153 for (const auto i : c10::irange(adjust)) { 154 adjusted.push_back(i); 155 } 156 for (auto index : index_ordering) { 157 adjusted.push_back(index + adjust); 158 } 159 return adjusted; 160 } 161 }; 162 } // namespace torch::jit 163