xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/value_refinement_utils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/jit_type.h>
4 #include <torch/csrc/jit/ir/alias_analysis.h>
5 #include <torch/csrc/jit/ir/ir_views.h>
6 #include <torch/csrc/jit/jit_log.h>
7 #include <torch/csrc/jit/passes/dead_code_elimination.h>
8 #include <torch/csrc/jit/passes/peephole.h>
9 #include <torch/csrc/jit/passes/peephole_list_idioms.h>
10 #include <torch/csrc/jit/runtime/graph_executor.h>
11 
12 namespace torch::jit {
13 
14 // Refine from Value of type List -> len of list
15 // If a refinement mapping of List Value * -> len is present in a block
16 // the list is guaranteed to be that length
17 // TODO: vector may be faster
18 using ListRefinement = std::unordered_map<Value*, int64_t>;
19 
20 TORCH_API ListRefinement
21 intersectRefinements(const ListRefinement& ref1, const ListRefinement& ref2);
22 
23 TORCH_API ListRefinement
24 unionRefinements(const ListRefinement& ref1, const ListRefinement& ref2);
25 
26 // Represents the refinement information that can be carried on a boolean
27 struct BooleanRefinementMapping {
BooleanRefinementMappingBooleanRefinementMapping28   BooleanRefinementMapping(
29       ListRefinement true_refine,
30       ListRefinement false_refine)
31       : true_refine_(std::move(true_refine)),
32         false_refine_(std::move(false_refine)){};
33   BooleanRefinementMapping() = default; // empty
34 
FalseRefinementsBooleanRefinementMapping35   static BooleanRefinementMapping FalseRefinements(
36       ListRefinement false_refine) {
37     return BooleanRefinementMapping({}, std::move(false_refine));
38   }
39 
TrueRefinementsBooleanRefinementMapping40   static BooleanRefinementMapping TrueRefinements(ListRefinement true_refine) {
41     return BooleanRefinementMapping(std::move(true_refine), {});
42   }
43 
intersectBooleanRefinementMappingBooleanRefinementMapping44   BooleanRefinementMapping intersectBooleanRefinementMapping(
45       BooleanRefinementMapping& other) {
46     return BooleanRefinementMapping(
47         intersectRefinements(true_refine_, other.true_refine()),
48         intersectRefinements(false_refine_, other.false_refine()));
49   }
50 
true_refineBooleanRefinementMapping51   ListRefinement& true_refine() {
52     return true_refine_;
53   }
54 
false_refineBooleanRefinementMapping55   ListRefinement& false_refine() {
56     return false_refine_;
57   }
58 
59  private:
60   ListRefinement true_refine_;
61   ListRefinement false_refine_;
62 };
63 
64 TORCH_API void joinIfRefinements(
65     Node* if_node,
66     std::unordered_set<Block*>& throwing_blocks,
67     ListRefinement& curr_block_refinements,
68     ListRefinement& true_block_refinements,
69     ListRefinement& false_block_refinements,
70     std::unordered_map<Value*, BooleanRefinementMapping>& info);
71 
72 // handles adding blocks to throwing blocks and propagating refinements via
73 // boolean comparisons
74 TORCH_API bool handleCommonRefinentOperators(
75     Node* n,
76     std::unordered_set<Block*>& throwing_blocks,
77     std::unordered_map<Value*, BooleanRefinementMapping>& info);
78 
79 } // namespace torch::jit
80