xref: /aosp_15_r20/external/pytorch/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import math
3import operator
4import traceback
5from functools import partial
6from typing import Callable, Dict, List, NamedTuple, Set
7
8import sympy
9
10import torch
11import torch.fx
12from torch.utils._sympy.value_ranges import ValueRanges
13from torch.utils._sympy.numbers import int_oo
14from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
15from torch.fx.passes.infra.pass_base import PassBase, PassResult
16
# Names exported by `from <this module> import *`. Only InputDim is listed;
# the other definitions visible in this file are underscore-prefixed.
__all__ = ["InputDim"]
18
19
class InputDim(NamedTuple):
    """Immutable ``(input_name, dim)`` pair.

    NOTE(review): presumably identifies one dimension of one named input;
    nothing in the visible part of this file constructs or consumes it, so
    confirm the exact meaning against the callers.
    """

    # Name of the input (a string).
    input_name: str
    # Integer dimension index associated with that input.
    dim: int
23
24
def _convert_to_int(val):
    """Turn a sympy bound into a plain Python number.

    Returns ``math.inf`` for ``sympy.oo`` / ``int_oo``, ``-math.inf`` for
    their negations, and ``int(val)`` for a ``sympy.Integer``.

    Raises:
        RuntimeError: if ``val`` is none of the above.
    """
    # Each row pairs the spellings of an infinity with its Python equivalent.
    infinity_table = (
        ((sympy.oo, int_oo), math.inf),
        ((-sympy.oo, -int_oo), -math.inf),
    )
    for spellings, python_value in infinity_table:
        if val in spellings:
            return python_value

    if not isinstance(val, sympy.Integer):
        raise RuntimeError(
            "Export constraints cannot be non-integer expressions"
        )
    return int(val)
36
37
def _convert_range_to_int(range: ValueRanges):
    """Return ``(lower, upper)`` of ``range`` as plain Python numbers.

    Both ends are converted with ``_convert_to_int`` (lower first), so the
    result may contain ``-math.inf`` / ``math.inf`` and a non-integer bound
    raises ``RuntimeError``.
    """
    assert isinstance(range, ValueRanges)
    return _convert_to_int(range.lower), _convert_to_int(range.upper)
43
44
class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
    """FX pass that inserts runtime range checks for unbacked symbols.

    For every ``call_function`` node whose ``meta["val"]`` is (or is a tensor
    whose shape contains) an unbacked symbol, the pass looks up that symbol's
    range in ``range_constraints`` and inserts ``_assert_async`` nodes that
    check the finite bounds. Symbols that already have an inline
    ``_assert_scalar`` check in the graph, or that were already handled by
    this pass instance, are skipped.
    """

    def __init__(
        self,
        range_constraints: Dict[sympy.Symbol, ValueRanges],
    ):
        """Store the symbol -> range mapping and reset the bookkeeping state."""
        super().__init__()
        self.range_constraints: Dict[sympy.Symbol, ValueRanges] = range_constraints
        # Symbols for which this pass instance has already emitted assertions.
        self._asserts_generated_unbacked_symbols: Set[sympy.Symbol] = set()
        # Number of comparison/assert groups inserted so far.
        self.counter = 0
        # Recomputed at the start of every ``call``; initialised here so the
        # attribute always exists.
        self.existing_inline_assertions: Dict[sympy.Symbol, ValueRanges] = {}

    def _assert_range_constraint(self, node, lower, upper, assert_msg):
        """Insert ``node >= lower`` and ``node <= upper`` checks after ``node``.

        A bound equal to ``-math.inf`` / ``math.inf`` is skipped. The second
        check is chained after the assert node produced by the first one.
        """
        last_node = node
        if lower > -math.inf:
            last_node = self._insert_assert_async(last_node, operator.ge, node, lower, assert_msg)

        if upper < math.inf:
            last_node = self._insert_assert_async(last_node, operator.le, node, upper, assert_msg)

    def _insert_assert_async(self, last_node, op, lower, upper, assert_msg):
        """
        Inserts assert_async call_function nodes in the graph. This function is
        called **during** the interpreter-based pass.

        Three nodes are added after ``last_node``: ``op(lower, upper)``, a
        ``scalar_tensor`` wrapping that result, and ``_assert_async.msg`` on
        the tensor with ``assert_msg``. Despite their names, ``lower`` and
        ``upper`` are simply the left and right operands of ``op`` (the caller
        passes the checked node and the bound). Returns the assert node.
        """
        self.counter += 1
        graph = last_node.graph
        with graph.inserting_after(last_node):
            cmp = graph.call_function(op, (lower, upper), {})
        with graph.inserting_after(cmp):
            cmp_tensor = graph.call_function(torch.ops.aten.scalar_tensor.default, (cmp,), {})
        with graph.inserting_after(cmp_tensor):
            assert_async = graph.call_function(
                torch.ops.aten._assert_async.msg,
                (cmp_tensor, assert_msg),
                {},
            )
        return assert_async

    @staticmethod
    def _assert_on_sym_size(node, assert_msg, dim, inner_cb):
        """Insert ``sym_size.int(node, dim)`` after ``node`` and run ``inner_cb`` on it."""
        graph = node.graph
        with graph.inserting_after(node):
            dim_node = graph.call_function(
                torch.ops.aten.sym_size.int,
                (node, dim),
                {},
            )
        inner_cb(node=dim_node, assert_msg=assert_msg)

    def _collect_assertions(self, val):
        """Gather the assertion callbacks and message suffixes needed for ``val``.

        Returns two parallel lists. Each callback takes ``node`` and
        ``assert_msg`` keyword arguments and inserts the checks when called;
        each message is the suffix to append to the node's name.
        """
        call_backs: List[Callable] = []
        messages: List[str] = []
        if isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)):
            symbol = val.node.expr
            if symbol in self.existing_inline_assertions:
                return call_backs, messages
            if isinstance(symbol, sympy.Symbol) and free_unbacked_symbols(symbol):
                if symbol in self._asserts_generated_unbacked_symbols:
                    return call_backs, messages
                # We only care about unbacked symints for these inline
                # constraints, which are prefixed with 'u'
                constraint = self.range_constraints[symbol]
                min_val, max_val = _convert_range_to_int(constraint)
                assert_msg = f" is outside of inline constraint [{min_val}, {max_val}]."
                call_backs.append(
                    partial(self._assert_range_constraint, lower=min_val, upper=max_val)
                )
                messages.append(assert_msg)
                self._asserts_generated_unbacked_symbols.add(symbol)

        elif isinstance(val, torch.Tensor):
            for i, sym in enumerate(val.shape):
                cbs, msgs = self._collect_assertions(sym)
                for cb, msg in zip(cbs, msgs):
                    # Bind ``cb`` now. A closure over the loop variable would
                    # see only its final value when the callback runs later,
                    # applying the last dimension's range to every dimension.
                    call_backs.append(
                        partial(self._assert_on_sym_size, dim=i, inner_cb=cb)
                    )
                    messages.append(f".shape[{i}]" + msg)
        return call_backs, messages

    def call(self, graph_module) -> PassResult:
        """Run the pass over ``graph_module`` and all nested GraphModules.

        Returns ``PassResult(graph_module, False)`` when nothing was inserted
        and ``self`` is exactly this class (not a subclass); otherwise fills in
        missing ``stack_trace`` metadata on the top-level graph and returns
        ``PassResult(graph_module, True)``.
        """
        self.existing_inline_assertions = _get_existing_inline_assertions(
            graph_module, self.range_constraints
        )

        for module in graph_module.modules():
            if not isinstance(module, torch.fx.GraphModule):
                continue
            for node in module.graph.nodes:
                if node.op != "call_function":
                    continue
                if "val" not in node.meta:
                    continue

                # In general, we may have to deal with cases such as ret[1].shape[0].
                # We first find out which symbols require an assertion, then follow the
                # path from ret to the symbol, building the nodes and the message
                # piece-wise along the way.
                #
                # The callbacks are collected bottom-up and only executed here at the
                # top level, because creating the node for shape[0] requires the node
                # for shape, which in turn requires the node for ret[1], etc.
                callbacks, messages = self._collect_assertions(node.meta["val"])
                for cb, msg in zip(callbacks, messages):
                    cb(node=node, assert_msg=f"{node}" + msg)

            module.recompile()

        # Sometimes this pass would return a wrong graph where we have mismatched
        # node names in signature. Before we fix it, let's just skip it.
        if self.counter == 0 and type(self) is _AddRuntimeAssertionsForInlineConstraintsPass:
            return PassResult(graph_module, False)

        # Populate the stack trace with dummy vals to respect IR
        for node in graph_module.graph.nodes:
            if not node.meta.get("stack_trace", None) and node.op not in ["placeholder", "output"]:
                node.meta["stack_trace"] = "".join(traceback.format_stack(limit=1))
        return PassResult(graph_module, True)
160
161
def _get_existing_inline_assertions(
    graph_module: torch.fx.GraphModule,
    range_constraints: Dict[sympy.Symbol, ValueRanges],
) -> Dict[sympy.Symbol, ValueRanges]:
    """Collect the ranges already enforced by ``_assert_scalar`` nodes.

    Walks ``graph_module`` and every nested ``torch.fx.GraphModule`` looking
    for ``_assert_scalar`` nodes whose first argument is a two-argument
    ``operator.le`` / ``operator.ge`` comparison between a symbol and a Python
    ``int``. Each such comparison contributes an upper or lower bound; bounds
    for the same symbol are intersected.

    Raises:
        RuntimeError: if a matched symbol is missing from ``range_constraints``.
    """

    def symbol_or_passthrough(arg):
        # A node whose meta["val"] is a SymInt stands for that SymInt's sympy
        # expression; anything else is returned unchanged.
        if not isinstance(arg, torch.fx.Node):
            return arg
        meta_val = arg.meta.get("val")
        if isinstance(meta_val, torch.SymInt):
            return meta_val.node.expr
        return arg

    found: Dict[sympy.Symbol, ValueRanges] = {}

    for submodule in graph_module.modules():
        if not isinstance(submodule, torch.fx.GraphModule):
            continue

        # The pattern being matched looks something like:
        # %_local_scalar_dense = call_function[target=torch.ops.aten._local_scalar_dense.default](args = (%arg1_1,), kwargs = {})
        # %ge = call_function[target=operator.ge](args = (%_local_scalar_dense, 0), kwargs = {})
        # %_assert_scalar = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge, "..."), kwargs = {})
        for node in submodule.graph.nodes:
            if node.target != torch.ops.aten._assert_scalar.default:
                continue

            comparison = node.args[0]
            if not isinstance(comparison, torch.fx.Node):
                continue
            if comparison.op != "call_function":
                continue
            if comparison.target not in (operator.le, operator.ge):
                continue
            if len(comparison.args) != 2:
                continue

            smaller, larger = map(symbol_or_passthrough, comparison.args)
            # Normalise `a >= b` to `b <= a` so only one orientation remains.
            if comparison.target == operator.ge:
                smaller, larger = larger, smaller

            if isinstance(smaller, sympy.Symbol) and isinstance(larger, int):
                # symbol <= constant: an upper bound.
                symint, ends = smaller, (-math.inf, larger)
            elif isinstance(larger, sympy.Symbol) and isinstance(smaller, int):
                # constant <= symbol: a lower bound.
                symint, ends = larger, (smaller, math.inf)
            else:
                continue

            if symint not in range_constraints:
                raise RuntimeError(f"Unable to find symint {symint} in {range_constraints}")

            known = found.get(symint, ValueRanges(-math.inf, math.inf))
            found[symint] = known & ValueRanges(*ends)

    return found
227    return existing_inline_assertions
228