xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/value_refinement_utils.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/util/irange.h>
2 #include <torch/csrc/jit/passes/value_refinement_utils.h>
3 
4 namespace torch::jit {
5 
// [value refinement algorithm]

// When a comparison like `cond = len(x) == 4` or `cond = len(x) != 4` is made,
// the value of `cond` carries information (refinements) about the length of
// `x`. When `cond` is used as the condition of an if statement, the
// information it carries for its true value can be inserted into the true
// block, and likewise for its false value.
// For something like `y = len(x) if len(x) == 1 else 1`, in the true branch
// we can replace len(x) with 1 because the true refinements from
// `len(x) == 1` will be present in the true block.
// Additionally, we can optimize something like:
// if len(x) != 4:
//    raise Exception(...)
// return len(x)
// Because the true block always throws, whatever refinements exist in the
// false block become present in the owning block of the if node. We can also
// merge refinements carried by two different booleans across an if node join
// by taking the intersection of their refinements.
// if cond:
//    z = len(x) == 4 and len(y) == 5
// else:
//    z = len(x) == 4
// Here, z's true value will refine len(x) to 4, but not len(y).
// If the code was written as:
// if cond:
//    z = len(x) == 4 and len(y) == 5
// else:
//    z = False
//
// Then z's true value would refine both x and y, because if z is true it must
// have come from the true block. Code written with `and` or `or` desugars to
// something similar. Additionally, any True refinements that were present on
// `cond` can also be associated with the if node's True output value.
39 
40 // The intersection of the refinements is the Value* which are in both
41 // refinements and are refined to the same length
42 // in an example like:
43 // if cond:
44 //    x = len(a) == 4 and len(b) == 5
45 // else:
46 //    x = len(a) == 4
47 // For the x output of the node we take the intersection between
48 // the refinements stored on each block output, which will result
49 // in only the refinement of len(a) == 4
intersectRefinements(const ListRefinement & ref1,const ListRefinement & ref2)50 ListRefinement intersectRefinements(
51     const ListRefinement& ref1,
52     const ListRefinement& ref2) {
53   ListRefinement out;
54   for (const auto& pair : ref1) {
55     auto val2 = ref2.find(pair.first);
56     if (val2 != ref2.end() && val2->second == pair.second) {
57       out[pair.first] = pair.second;
58     }
59   }
60   return out;
61 }
62 
63 // To union, just take all refinements from both inputs. We do not need to worry
64 // about len refinements disagreeing because a path like `if len(x) == 4 and
65 // len(x) == 5` will never be taken
66 // in an example like:
67 // if len(a) == 5:
68 //     x = len(b) == 4
69 // else:
70 //     x = False
71 // For the output x Value, if is true then the refinements present in the true
72 // block must also be true, so we take the union of `len(a) == 5` and len(b) ==
73 // 4` and assign them to true refinements of the output x value. This is a very
74 // common pattern in desugaring of `and` or `or` boolean expressions
unionRefinements(const ListRefinement & ref1,const ListRefinement & ref2)75 ListRefinement unionRefinements(
76     const ListRefinement& ref1,
77     const ListRefinement& ref2) {
78   ListRefinement out = ref1;
79   out.insert(ref2.begin(), ref2.end());
80   return out;
81 }
82 
joinIfRefinements(Node * if_node,std::unordered_set<Block * > & throwing_blocks,ListRefinement & curr_block_refinements,ListRefinement & true_block_refinements,ListRefinement & false_block_refinements,std::unordered_map<Value *,BooleanRefinementMapping> & boolean_value_refinements)83 void joinIfRefinements(
84     Node* if_node,
85     std::unordered_set<Block*>& throwing_blocks,
86     ListRefinement& curr_block_refinements,
87     ListRefinement& true_block_refinements,
88     ListRefinement& false_block_refinements,
89     std::unordered_map<Value*, BooleanRefinementMapping>&
90         boolean_value_refinements) {
91   IfView if_n(if_node);
92   Block* b = if_node->owningBlock();
93 
94   bool true_block_throws = throwing_blocks.count(if_n.thenBlock());
95   bool false_block_throws = throwing_blocks.count(if_n.elseBlock());
96 
97   // if one block throws, the refinements for the other block
98   // become present in the current block, and all bool outputs
99   // of the if node take their refinements from non throwing block
100   // output
101 
102   if (true_block_throws || false_block_throws) {
103     if (true_block_throws && false_block_throws) {
104       throwing_blocks.insert(b);
105       return;
106     }
107     if (true_block_throws) {
108       curr_block_refinements.insert(
109           false_block_refinements.begin(), false_block_refinements.end());
110     } else {
111       curr_block_refinements.insert(
112           true_block_refinements.begin(), true_block_refinements.end());
113     }
114     Block* non_throwing_block =
115         true_block_throws ? if_node->blocks().at(1) : if_node->blocks().at(0);
116     for (const auto i : c10::irange(if_n.outputs().size())) {
117       if (boolean_value_refinements.count(
118               non_throwing_block->outputs().at(i))) {
119         boolean_value_refinements[if_node->outputs().at(i)] =
120             boolean_value_refinements[non_throwing_block->outputs().at(i)];
121       }
122     }
123     return;
124   }
125 
126   for (const auto i : c10::irange(if_n.outputs().size())) {
127     if (!(if_n.outputs().at(i)->type() == BoolType::get())) {
128       return;
129     }
130     Value* true_v = if_n.thenOutputs().at(i);
131     Value* false_v = if_n.elseOutputs().at(i);
132 
133     if (!boolean_value_refinements.count(true_v) &&
134         !boolean_value_refinements.count(false_v) &&
135         !constant_as<bool>(true_v) && !constant_as<bool>(false_v)) {
136       return;
137     }
138 
139     // if either block has a constant bool output, e.g. `true` on the
140     // true block, then for the `false` value we can take the false
141     // refinements present on the false block and from the other block
142     // output value bc if the output is false it had to have come from the
143     // false block. if len(a) == 5:
144     //     x = len(b) == 4
145     // else:
146     //     x = False
147     // if x is true, then we know both len(a) == 5 and len(b) == 4
148     //
149     // if neither block has a constant bool value, we just take the
150     // intersection of the refinements from boolean outputs.
151     // if cond:
152     //    x = len(a) == 4 and len(b) == 5
153     // else:
154     //    x = len(a) == 4
155     // here, we know if x is true, then len(a) == 4, but not len(b)
156     // == 5, because that refinement is not present in the true block.
157     // TODO: could also take intersection of refinements present in
158     // both blocks, but it's not a real use case.
159 
160     // boolean_value_refinements[value] is safe to access because
161     // BooleanRefinementMapping has a default constructor
162 
163     BooleanRefinementMapping out;
164     if (auto maybe_bool = constant_as<bool>(true_v)) {
165       if (*maybe_bool) {
166         out = BooleanRefinementMapping::FalseRefinements(unionRefinements(
167             boolean_value_refinements[false_v].false_refine(),
168             false_block_refinements));
169       } else {
170         out = BooleanRefinementMapping::TrueRefinements(unionRefinements(
171             boolean_value_refinements[false_v].true_refine(),
172             false_block_refinements));
173       }
174     } else if (auto maybe_bool = constant_as<bool>(false_v)) {
175       if (*maybe_bool) {
176         out = BooleanRefinementMapping::FalseRefinements(unionRefinements(
177             boolean_value_refinements[true_v].false_refine(),
178             true_block_refinements));
179       } else {
180         out = BooleanRefinementMapping::TrueRefinements(unionRefinements(
181             boolean_value_refinements[true_v].true_refine(),
182             true_block_refinements));
183       }
184     } else if (
185         boolean_value_refinements.count(true_v) &&
186         boolean_value_refinements.count(false_v)) {
187       out = boolean_value_refinements[true_v].intersectBooleanRefinementMapping(
188           boolean_value_refinements[false_v]);
189     }
190     boolean_value_refinements[if_n.outputs().at(i)] = out;
191   }
192 }
193 
handleCommonRefinentOperators(Node * n,std::unordered_set<Block * > & throwing_blocks,std::unordered_map<Value *,BooleanRefinementMapping> & info)194 bool handleCommonRefinentOperators(
195     Node* n,
196     std::unordered_set<Block*>& throwing_blocks,
197     std::unordered_map<Value*, BooleanRefinementMapping>& info) {
198   if (n->kind() == prim::RaiseException) {
199     throwing_blocks.insert(n->owningBlock());
200     return true;
201   }
202   if (n->kind() == aten::__not__ &&
203       n->inputs().at(0)->type()->cast<BoolType>()) {
204     // __not__(inp) -> reverse refinements
205     if (info.count(n->input())) {
206       auto& input_ref = info[n->input()];
207       info[n->output()] = BooleanRefinementMapping(
208           input_ref.false_refine(), input_ref.true_refine());
209     }
210     return true;
211   }
212   if (n->matches("aten::eq(bool a, bool b) -> bool") ||
213       (n->matches("aten::ne(bool a, bool b) -> bool"))) {
214     for (size_t const_index : {0, 1}) {
215       if (n->input(const_index)->node()->kind() != prim::Constant) {
216         continue;
217       }
218       auto const_input = constant_as<bool>(n->input(const_index)).value();
219       auto non_const_input = n->input(1 - const_index);
220       if (!info.count(non_const_input)) {
221         continue;
222       }
223       // value == False / value != True -> equivalent to __not__ value
224       // value == True / value != False -> equivalent to value
225       auto& input_ref = info[non_const_input];
226       if ((!const_input && n->kind() == aten::eq) ||
227           (const_input && n->kind() == aten::ne)) {
228         info[n->output()] = BooleanRefinementMapping(
229             input_ref.false_refine(), input_ref.true_refine());
230       } else {
231         info[n->output()] = BooleanRefinementMapping(
232             input_ref.true_refine(), input_ref.false_refine());
233       }
234     }
235     return true;
236   }
237   return false;
238 }
239 
240 } // namespace torch::jit
241