1# mypy: allow-untyped-defs 2import math 3import operator 4import traceback 5from functools import partial 6from typing import Callable, Dict, List, NamedTuple, Set 7 8import sympy 9 10import torch 11import torch.fx 12from torch.utils._sympy.value_ranges import ValueRanges 13from torch.utils._sympy.numbers import int_oo 14from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols 15from torch.fx.passes.infra.pass_base import PassBase, PassResult 16 17__all__ = ["InputDim"] 18 19 20class InputDim(NamedTuple): 21 input_name: str 22 dim: int 23 24 25def _convert_to_int(val): 26 # Convert simple sympy Integers into concrete int 27 if val in (sympy.oo, int_oo): 28 return math.inf 29 if val in (-sympy.oo, -int_oo): 30 return -math.inf 31 if isinstance(val, sympy.Integer): 32 return int(val) 33 raise RuntimeError( 34 "Export constraints cannot be non-integer expressions" 35 ) 36 37 38def _convert_range_to_int(range: ValueRanges): 39 assert isinstance(range, ValueRanges) 40 min_val = _convert_to_int(range.lower) 41 max_val = _convert_to_int(range.upper) 42 return min_val, max_val 43 44 45class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase): 46 def __init__( 47 self, 48 range_constraints: Dict[sympy.Symbol, ValueRanges], 49 ): 50 super().__init__() 51 self.range_constraints: Dict[sympy.Symbol, ValueRanges] = range_constraints 52 self._asserts_generated_unbacked_symbols: Set[sympy.Symbol] = set() 53 self.counter = 0 54 55 def _assert_range_constraint(self, node, lower, upper, assert_msg): 56 last_node = node 57 if lower > -math.inf: 58 last_node = self._insert_assert_async(last_node, operator.ge, node, lower, assert_msg) 59 60 if upper < math.inf: 61 last_node = self._insert_assert_async(last_node, operator.le, node, upper, assert_msg) 62 63 def _insert_assert_async(self, last_node, op, lower, upper, assert_msg): 64 """ 65 Inserts assert_async call_function nodes in the graph. This function is 66 called **during** the interpreter-based pass. 67 """ 68 self.counter += 1 69 graph = last_node.graph 70 with graph.inserting_after(last_node): 71 cmp = graph.call_function(op, (lower, upper), {}) 72 with graph.inserting_after(cmp): 73 cmp_tensor = graph.call_function(torch.ops.aten.scalar_tensor.default, (cmp,), {}) 74 with graph.inserting_after(cmp_tensor): 75 assert_async = graph.call_function( 76 torch.ops.aten._assert_async.msg, 77 (cmp_tensor, assert_msg), 78 {}, 79 ) 80 return assert_async 81 82 def call(self, graph_module) -> PassResult: 83 self.existing_inline_assertions = _get_existing_inline_assertions( 84 graph_module, self.range_constraints 85 ) 86 87 for module in graph_module.modules(): 88 if not isinstance(module, torch.fx.GraphModule): 89 continue 90 for node in module.graph.nodes: 91 if node.op != "call_function": 92 continue 93 if "val" not in node.meta: 94 continue 95 96 val = node.meta["val"] 97 # In general, we may have to deal the case such as: ret[1].shape[0]. 98 # We need first find out what symbols require assertion, then we need to follow the path 99 # from ret to the symbol, construct the proxies along the way and construct the messages 100 # piece-wise at the same time. 101 # 102 # We use post-order traversal to collect all the proxies callbacks needed, construct 103 # the error message callbacks, and at the top-level traversal tree we execute all the callbacks. 104 # We need the callbacks because, in order to call the function to create a proxy for shape[0], we 105 # need the proxy for shape, which further requires the proxy for ret[1], etc. 106 107 def add_assertions(val): 108 call_backs: List[Callable] = [] 109 messages: List[str] = [] 110 if isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)): 111 symbol = val.node.expr 112 if symbol in self.existing_inline_assertions: 113 return call_backs, messages 114 if isinstance(symbol, sympy.Symbol) and free_unbacked_symbols(symbol): 115 if symbol in self._asserts_generated_unbacked_symbols: 116 return call_backs, messages 117 # We only care about unbacked symints for these inline 118 # constraints, which are prefixed with 'u' 119 constraint = self.range_constraints[symbol] 120 min_val, max_val = _convert_range_to_int(constraint) 121 assert_msg = f" is outside of inline constraint [{min_val}, {max_val}]." 122 call_backs.append( 123 partial(self._assert_range_constraint, lower=min_val, upper=max_val) 124 ) 125 messages.append(assert_msg) 126 self._asserts_generated_unbacked_symbols.add(symbol) 127 128 elif isinstance(val, torch.Tensor): 129 for i, sym in enumerate(val.shape): 130 cbs, msgs = add_assertions(sym) 131 for cb, msg in zip(cbs, msgs): 132 def sym_size_cb(node, assert_msg, dim): 133 with node.graph.inserting_after(node): 134 dim_node = module.graph.call_function( 135 torch.ops.aten.sym_size.int, 136 (node, dim), 137 {}, 138 ) 139 cb(node=dim_node, assert_msg=assert_msg) 140 call_backs.append(partial(sym_size_cb, dim=i)) 141 messages.append(f".shape[{i}]" + msg) 142 return call_backs, messages 143 144 callbacks, messages = add_assertions(val) 145 for cb, msg in zip(callbacks, messages): 146 cb(node=node, assert_msg=f"{node}" + msg) 147 148 module.recompile() 149 150 # Sometimes this pass would return a wrong graph where we have mismatched 151 # node names in signature. Before we fix it, let's just skip it. 152 if self.counter == 0 and type(self) is _AddRuntimeAssertionsForInlineConstraintsPass: 153 return PassResult(graph_module, False) 154 155 # Populate the stack trace with dummy vals to respect IR 156 for node in graph_module.graph.nodes: 157 if not node.meta.get("stack_trace", None) and node.op not in ["placeholder", "output"]: 158 node.meta["stack_trace"] = "".join(traceback.format_stack(limit=1)) 159 return PassResult(graph_module, True) 160 161 162def _get_existing_inline_assertions( 163 graph_module: torch.fx.GraphModule, 164 range_constraints: Dict[sympy.Symbol, ValueRanges], 165) -> Dict[sympy.Symbol, ValueRanges]: 166 existing_inline_assertions: Dict[sympy.Symbol, ValueRanges] = {} 167 168 for module in graph_module.modules(): 169 if not isinstance(module, torch.fx.GraphModule): 170 continue 171 172 # Find all the existing inline assertions. They will look something like: 173 # %_local_scalar_dense = call_function[target=torch.ops.aten._local_scalar_dense.default](args = (%arg1_1,), kwargs = {}) 174 # %ge = call_function[target=operator.ge](args = (%_local_scalar_dense, 0), kwargs = {}) 175 # %_assert_scalar = call_function[target=torch.ops.aten._assert_scalar.default](args = (%scalar_tensor, "..."), kwargs = {}) 176 for node in module.graph.nodes: 177 if node.target != torch.ops.aten._assert_scalar.default: 178 continue 179 180 compare_arg = node.args[0] 181 if not ( 182 isinstance(compare_arg, torch.fx.Node) and 183 compare_arg.op == "call_function" and 184 compare_arg.target in (operator.le, operator.ge) and 185 len(compare_arg.args) == 2 186 ): 187 continue 188 189 compare_op = compare_arg.target 190 lhs, rhs = compare_arg.args 191 192 def maybe_get_symint(x): 193 if ( 194 isinstance(x, torch.fx.Node) and 195 "val" in x.meta and 196 isinstance(x.meta["val"], torch.SymInt) 197 ): 198 return x.meta["val"].node.expr 199 return x 200 201 lhs = maybe_get_symint(lhs) 202 rhs = maybe_get_symint(rhs) 203 204 if compare_op == operator.ge: 205 lhs, rhs = rhs, lhs 206 207 if isinstance(lhs, sympy.Symbol) and isinstance(rhs, int): 208 symint = lhs 209 scalar = rhs 210 elif isinstance(rhs, sympy.Symbol) and isinstance(lhs, int): 211 symint = rhs 212 scalar = lhs 213 else: 214 continue 215 216 if symint not in range_constraints: 217 raise RuntimeError(f"Unable to find symint {symint} in {range_constraints}") 218 219 previous_range = existing_inline_assertions.get(symint, ValueRanges(-math.inf, math.inf)) 220 221 if symint is lhs: 222 bounds = ValueRanges(-math.inf, scalar) 223 else: 224 bounds = ValueRanges(scalar, math.inf) 225 existing_inline_assertions[symint] = previous_range & bounds 226 227 return existing_inline_assertions 228