1 #pragma once 2 3 #include <ATen/core/jit_type.h> 4 #include <torch/csrc/jit/ir/alias_analysis.h> 5 #include <torch/csrc/jit/ir/ir_views.h> 6 #include <torch/csrc/jit/jit_log.h> 7 #include <torch/csrc/jit/passes/dead_code_elimination.h> 8 #include <torch/csrc/jit/passes/peephole.h> 9 #include <torch/csrc/jit/passes/peephole_list_idioms.h> 10 #include <torch/csrc/jit/runtime/graph_executor.h> 11 12 namespace torch::jit { 13 14 // Refine from Value of type List -> len of list 15 // If a refinement mapping of List Value * -> len is present in a block 16 // the list is guaranteed to be that length 17 // TODO: vector may be faster 18 using ListRefinement = std::unordered_map<Value*, int64_t>; 19 20 TORCH_API ListRefinement 21 intersectRefinements(const ListRefinement& ref1, const ListRefinement& ref2); 22 23 TORCH_API ListRefinement 24 unionRefinements(const ListRefinement& ref1, const ListRefinement& ref2); 25 26 // Represents the refinement information that can be carried on a boolean 27 struct BooleanRefinementMapping { BooleanRefinementMappingBooleanRefinementMapping28 BooleanRefinementMapping( 29 ListRefinement true_refine, 30 ListRefinement false_refine) 31 : true_refine_(std::move(true_refine)), 32 false_refine_(std::move(false_refine)){}; 33 BooleanRefinementMapping() = default; // empty 34 FalseRefinementsBooleanRefinementMapping35 static BooleanRefinementMapping FalseRefinements( 36 ListRefinement false_refine) { 37 return BooleanRefinementMapping({}, std::move(false_refine)); 38 } 39 TrueRefinementsBooleanRefinementMapping40 static BooleanRefinementMapping TrueRefinements(ListRefinement true_refine) { 41 return BooleanRefinementMapping(std::move(true_refine), {}); 42 } 43 intersectBooleanRefinementMappingBooleanRefinementMapping44 BooleanRefinementMapping intersectBooleanRefinementMapping( 45 BooleanRefinementMapping& other) { 46 return BooleanRefinementMapping( 47 intersectRefinements(true_refine_, other.true_refine()), 48 intersectRefinements(false_refine_, other.false_refine())); 49 } 50 true_refineBooleanRefinementMapping51 ListRefinement& true_refine() { 52 return true_refine_; 53 } 54 false_refineBooleanRefinementMapping55 ListRefinement& false_refine() { 56 return false_refine_; 57 } 58 59 private: 60 ListRefinement true_refine_; 61 ListRefinement false_refine_; 62 }; 63 64 TORCH_API void joinIfRefinements( 65 Node* if_node, 66 std::unordered_set<Block*>& throwing_blocks, 67 ListRefinement& curr_block_refinements, 68 ListRefinement& true_block_refinements, 69 ListRefinement& false_block_refinements, 70 std::unordered_map<Value*, BooleanRefinementMapping>& info); 71 72 // handles adding blocks to throwing blocks and propagating refinements via 73 // boolean comparisons 74 TORCH_API bool handleCommonRefinentOperators( 75 Node* n, 76 std::unordered_set<Block*>& throwing_blocks, 77 std::unordered_map<Value*, BooleanRefinementMapping>& info); 78 79 } // namespace torch::jit 80