# xref: /aosp_15_r20/external/pytorch/torch/_export/verifier.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import inspect
import math
import operator
from collections.abc import Iterable
from typing import Any, Dict, final, List, Tuple, Type, TYPE_CHECKING

import torch
from torch._ops import HigherOrderOperator, OpOverload
from torch._subclasses.fake_tensor import FakeTensor
from torch.export.graph_signature import (
    CustomObjArgument,
    InputKind,
    SymIntArgument,
    TensorArgument,
    TokenArgument,
)
from torch.fx import GraphModule

if TYPE_CHECKING:
    from torch.export.exported_program import ExportedProgram


class SpecViolationError(Exception):
    pass


def is_functional(op: OpOverload) -> bool:
    return not op._schema.is_mutable
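
# Illustrative behavior (a sketch, based on the schemas of the aten ops named
# below): a functional overload has no mutable arguments, while an in-place
# overload mutates its first argument, e.g.
#
#   is_functional(torch.ops.aten.add.Tensor)   # True
#   is_functional(torch.ops.aten.add_.Tensor)  # False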


def _check_has_fake_tensor(node: torch.fx.Node) -> None:
    # TODO(angelayi): remove this in favor of _check_val
    return _check_val(node)


def _check_val(node: torch.fx.Node) -> None:
    from torch.fx.experimental.symbolic_shapes import SymBool, SymFloat, SymInt

    def _check_correct_val(val):
        if val is None:
            return True
        elif isinstance(val, (int, bool, str, float)):
            return True
        elif isinstance(val, (torch.memory_format, torch.dtype, torch.device, torch.layout)):
            return True
        elif isinstance(val, (FakeTensor, torch.Tensor)):  # TODO(zhxchen17) Remove Tensor.
            return True
        elif isinstance(val, (SymInt, SymFloat, SymBool)):
            return True
        elif isinstance(val, CustomObjArgument):
            return True
        elif isinstance(val, Iterable):
            return all(_check_correct_val(x) for x in val)
        return False

    def _no_returns(op):
        if not isinstance(op, OpOverload):
            return False
        return len(op._schema.returns) == 0

    if "val" not in node.meta:
        if node.op == "call_function" and _no_returns(node.target):
            return
        raise SpecViolationError(f"Node.meta {node.name} is missing val field.")

    val = node.meta["val"]
    if not _check_correct_val(val):
        raise SpecViolationError(f"Node.meta {node.name} has invalid val field {val}")
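
# Illustrative example (a sketch; the values shown are assumptions, not
# produced by this module): a call_function node returning a tensor is
# expected to carry a FakeTensor in node.meta["val"], e.g.
#
#   node.meta["val"]  # FakeTensor(..., size=(2, 3), dtype=torch.float32)
#
# and a node with multiple results may carry a tuple/list of such values,
# which _check_correct_val accepts through its Iterable branch.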


def _check_torch_fn(node: torch.fx.Node) -> None:
    torch_fn = node.meta.get("torch_fn")
    if torch_fn is None:
        raise SpecViolationError(f"Unable to find torch_fn metadata for node {node.name}")
    if not (
        isinstance(torch_fn, tuple)
        and len(torch_fn) == 2
        and isinstance(torch_fn[0], str)
        and isinstance(torch_fn[1], str)
    ):
        raise SpecViolationError(f"Node.meta {node.name} has invalid torch_fn field {torch_fn}")
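
# Illustrative shape of the torch_fn metadata (a sketch; the exact strings are
# assumptions): a pair of a unique call name and a "<kind>.<op>" identifier,
#
#   node.meta["torch_fn"] = ("relu_1", "function.relu")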


class _VerifierMeta(type):
    _registry: Dict[str, Type['Verifier']] = {}

    def __new__(metacls, name, bases, attrs):
        if bases:
            if "check" in attrs or "_check_graph_module" in attrs:
                raise SyntaxError("Overriding method 'check' or '_check_graph_module' is not allowed.")
            assert "dialect" in attrs and attrs["dialect"] != "ATEN"
        else:
            assert "check" in attrs
            assert "_check_graph_module" in attrs
            assert attrs["dialect"] == "ATEN"

        assert isinstance(attrs["dialect"], str)
        ret = type.__new__(metacls, name, bases, attrs)
        metacls._registry[attrs["dialect"]] = ret  # type: ignore[assignment]
        return ret
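
# Illustrative use of the registry (a sketch): defining a subclass with a new
# dialect registers it, after which load_verifier can retrieve it by name.
#
#   class MyDialectVerifier(Verifier):
#       dialect = "MY_DIALECT"
#
#   load_verifier("MY_DIALECT")  # -> MyDialectVerifier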


def getattr_recursive(obj: Any, target: str) -> Any:
    target_atoms = target.split('.')
    attr_itr = obj
    for i, atom in enumerate(target_atoms):
        if not hasattr(attr_itr, atom):
            raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i + 1])}")
        attr_itr = getattr(attr_itr, atom)
    return attr_itr
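
# Illustrative usage (a sketch): for a GraphModule gm with a submodule `sub`
# holding an attribute `weight`,
#
#   getattr_recursive(gm, "sub.weight")  # equivalent to gm.sub.weight
#
# except that a missing intermediate attribute raises a RuntimeError naming
# the nonexistent target path rather than an AttributeError.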


class Verifier(metaclass=_VerifierMeta):
    dialect = "ATEN"

    def allowed_builtin_ops(self) -> List:
        return [
            operator.getitem,
            operator.add,
            operator.mul,
            operator.sub,
            operator.truediv,
            operator.ge,
            operator.le,
            operator.gt,
            operator.lt,
            operator.eq,
            operator.ne,
            operator.floordiv,
            operator.mod,
            operator.and_,
            operator.or_,
            operator.not_,
            operator.pow,
            operator.neg,
            operator.abs,
            math.ceil,
            math.floor,
            math.trunc,
        ]
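
    # Note (illustrative): these builtins typically appear in exported graphs
    # through symbolic-shape arithmetic on SymInts, e.g. a call_function node
    # computing operator.add(s0, 1) for a dynamic dimension s0.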

    def allowed_op_types(self) -> Tuple[Type[Any], ...]:
        return (OpOverload, HigherOrderOperator)

    def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
        return (torch.fx.GraphModule,)

    def check_valid_op(self, op):
        """Hook for subclasses to add dialect-specific per-op checks."""

    def check_additional(self, gm: GraphModule) -> None:
        """
        Additional checks that are specific to some dialects.
        """

    @final
    def check(self, ep: "ExportedProgram") -> None:
        self._check_graph_module(ep.graph_module)
        _verify_exported_program_module_call_graph(ep)
        _verify_exported_program_signature(ep)

    @final
    def _check_graph_module(self, gm: torch.fx.GraphModule) -> None:
        def _allowed_getattr_types() -> Tuple[Type[Any], ...]:
            ret = self.allowed_getattr_types()
            assert not any(t is object for t in ret)
            return ret

        def _check_valid_op(op) -> None:
            def _allowed_builtin_ops() -> List:
                ret = self.allowed_builtin_ops()
                assert all(inspect.isbuiltin(b) for b in ret)
                return ret

            def _allowed_op_types() -> Tuple[Type[Any], ...]:
                ret = self.allowed_op_types()
                assert not any(t is object for t in ret)
                return ret

            # TODO Remove this allowlist.
            _allowed_torch_functions = (
                torch.autograd.grad_mode.set_grad_enabled,
                torch.sym_int,
                torch.sym_float,
                torch.sym_ite,
                torch.sym_max,
                torch.sym_min,
                torch.sym_not,
                torch.sym_sqrt,
                # TODO (tmanlaibaatar)
                # Predispatch export is able to contain autograd ops.
                # These will be modeled as HOOs later.
                torch._C._set_grad_enabled,
            )

            if not isinstance(op, _allowed_op_types()):
                if op not in _allowed_builtin_ops() and op not in _allowed_torch_functions:
                    raise SpecViolationError(
                        f"Operator '{op}' is not an allowed operator type: {_allowed_op_types()}\n"
                        f"Valid builtin ops: {_allowed_builtin_ops()}\n"
                        f"Valid torch functions: {_allowed_torch_functions}"
                    )

            if isinstance(op, OpOverload):
                # All ops must be functional, except in the TRAINING dialect.
                # TODO (tmanlaibaatar) more proper way is needed here
                if self.dialect != "TRAINING" and not is_functional(op):
                    raise SpecViolationError(
                        f"operator '{op}' is not functional"
                    )
            self.check_valid_op(op)

        for mod in gm.modules():
            if not isinstance(mod, torch.fx.GraphModule):
                continue

            mod.graph.lint()
            for node in mod.graph.nodes:
                # TODO(T140410192): should have fake tensor for all dialects
                if node.op in {"call_module", "call_method"}:
                    raise SpecViolationError(
                        f"call_module and call_method are not valid: got target '{node.target}'"
                    )

                elif node.op == "call_function":
                    _check_val(node)

                    _check_valid_op(node.target)

                elif node.op == "get_attr":
                    if not isinstance(node.target, str):
                        raise SpecViolationError(
                            f"Expected get_attr target to be string, but got {type(node.target)}"
                        )

                    attr = getattr_recursive(mod, node.target)
                    if isinstance(attr, torch.nn.Module):
                        def _is_type(name, ty):
                            return isinstance(getattr(attr, name, None), ty)
                        if type(attr).__name__ == "LoweredBackendModule":
                            if _is_type("backend_id", str) \
                                    and _is_type("processed_bytes", bytes) \
                                    and _is_type("compile_specs", list) \
                                    and hasattr(attr, "original_module"):
                                continue
                            else:
                                backend_id = getattr(attr, "backend_id", None)
                                processed_bytes = getattr(attr, "processed_bytes", None)
                                compile_specs = getattr(attr, "compile_specs", None)
                                raise SpecViolationError(
                                    f"Invalid get_attr type {type(attr)}. \n"
                                    f"LoweredBackendModule fields: "
                                    f"backend_id(str) : {type(backend_id)}, "
                                    f"processed_bytes(bytes) : {type(processed_bytes)}, "
                                    f"compile_specs(list) : {type(compile_specs)}"
                                )

                    if not isinstance(attr, _allowed_getattr_types()):
                        raise SpecViolationError(
                            f"Invalid get_attr type {type(attr)}. \n"
                            f"Valid get_attr types: {_allowed_getattr_types()}"
                        )

                elif node.op == "placeholder":
                    _check_val(node)
                # TODO(zhxchen17)
                # elif node.op == "output":
                #     _check_flattened_outputs()

        self.check_additional(gm)


class TrainingIRVerifier(Verifier):
    dialect = "TRAINING"
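
# Note: as enforced in _check_graph_module above, TRAINING is the one dialect
# in which mutable (non-functional) OpOverloads are permitted.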


def _verify_exported_program_module_call_graph(exported_program) -> None:
    module_call_graph = exported_program.module_call_graph
    nodes = {node.name for node in exported_program.graph.nodes}
    for entry in module_call_graph:
        if entry.signature is not None:
            for arg in entry.signature.inputs:
                if arg.name and arg.name not in nodes:
                    raise SpecViolationError(
                        f"Input {arg.name} does not exist in the graph."
                    )
            for arg in entry.signature.outputs:
                if arg.name and arg.name not in nodes:
                    raise SpecViolationError(
                        f"Output {arg.name} does not exist in the graph."
                    )
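
# Illustrative shape of a module_call_graph entry (a sketch; field names are
# assumptions drawn from torch.export.graph_signature): each entry may carry a
# signature whose inputs/outputs reference graph node names, e.g.
#
#   entry.signature.inputs   # [TensorArgument(name="x"), ...]
#   entry.signature.outputs  # [TensorArgument(name="add"), ...]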


def _verify_exported_program_signature(exported_program) -> None:
    # Check that the ExportedProgram signature matches the graph.
    gs = exported_program.graph_signature

    # Check that every node in the signature exists in the graph.
    input_node_names = [
        node.name for node in exported_program.graph.nodes if node.op == "placeholder"
    ]

    if len(input_node_names) != len(gs.input_specs):
        raise SpecViolationError(
            f"Number of graph inputs ({len(input_node_names)}) "
            f"does not match number of inputs in the graph signature ({len(gs.input_specs)})"
        )

    for input_spec, node in zip(gs.input_specs, input_node_names):
        if isinstance(input_spec.arg, (TensorArgument, SymIntArgument)):
            if input_spec.arg.name != node:
                raise SpecViolationError(
                    f"Input spec name {input_spec.arg.name} does not match node name {node}"
                )

        if input_spec.kind == InputKind.USER_INPUT:
            continue

        elif input_spec.kind == InputKind.PARAMETER:
            if not isinstance(input_spec.arg, TensorArgument):
                raise SpecViolationError(
                    f"Parameter {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
                )
            if input_spec.target is None:
                raise SpecViolationError(
                    f"InputSpec for {input_spec.name} has no target."
                )

            param = input_spec.target
            if param not in exported_program.state_dict:
                raise SpecViolationError(
                    f"Parameter {param} is not in the state dict."
                )

            if not isinstance(exported_program.state_dict[param], torch.nn.Parameter):
                raise SpecViolationError(
                    f"State dict entry for parameter {param} is not an instance of torch.nn.Parameter."
                )

        elif input_spec.kind == InputKind.BUFFER:
            if not isinstance(input_spec.arg, TensorArgument):
                raise SpecViolationError(
                    f"Buffer {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
                )
            if input_spec.target is None:
                raise SpecViolationError(
                    f"InputSpec for {input_spec.name} has no target."
                )

            buffer = input_spec.target
            if input_spec.persistent is None:
                raise SpecViolationError(
                    f"Buffer {buffer} is missing a persistence flag."
                )

            if input_spec.persistent is True and buffer not in exported_program.state_dict:
                raise SpecViolationError(
                    f"Buffer {buffer} is not in the state dict."
                )

            if input_spec.persistent is False and buffer in exported_program.state_dict:
                raise SpecViolationError(
                    f"Non-persistent buffer {buffer} is in the state dict; it should not be."
                )

        elif input_spec.kind == InputKind.CONSTANT_TENSOR:
            if not isinstance(input_spec.arg, TensorArgument):
                raise SpecViolationError(
                    f"Constant tensor {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
                )
            if input_spec.target is None:
                raise SpecViolationError(
                    f"InputSpec for {input_spec.name} has no target."
                )

            tensor_const = input_spec.target
            if tensor_const not in exported_program.constants:
                raise SpecViolationError(
                    f"Constant tensor {tensor_const} is not in the constants dictionary."
                )

        elif input_spec.kind == InputKind.CUSTOM_OBJ:
            if not isinstance(input_spec.arg, CustomObjArgument):
                raise SpecViolationError(
                    f"Custom object {input_spec.name} is not a custom object argument. Found {input_spec.arg} instead."
                )
            if input_spec.target is None:
                raise SpecViolationError(
                    f"InputSpec for {input_spec.name} has no target."
                )

            custom_obj = input_spec.target
            if custom_obj not in exported_program.constants:
                raise SpecViolationError(
                    f"Custom object {custom_obj} is not in the constants dictionary."
                )

        elif input_spec.kind == InputKind.TOKEN:
            if not isinstance(input_spec.arg, TokenArgument):
                raise SpecViolationError(
                    f"Token input {input_spec.name} is not a token argument. Found {input_spec.arg} instead."
                )

        else:
            raise SpecViolationError(
                f"Unknown InputKind {input_spec.kind}."
            )

    # Check outputs
    output_node = list(exported_program.graph.nodes)[-1]
    assert output_node.op == "output"
    output_nodes = [
        arg.name if isinstance(arg, torch.fx.Node) else arg
        for arg in output_node.args[0]
    ]

    if len(output_nodes) != len(gs.output_specs):
        raise SpecViolationError(
            f"Number of output nodes ({len(output_nodes)}) does not match "
            "the number of outputs specified by the graph signature: \n"
            f"Number of mutated buffers: {len(gs.buffers_to_mutate)}. \n"
            f"Number of user outputs: {len(gs.user_outputs)}. \n"
        )

    num_tokens = len(gs.output_tokens)
    end = len(gs.buffers_to_mutate) + len(gs.user_inputs_to_mutate) + num_tokens
    mutate_nodes: List[str] = output_nodes[num_tokens:end]
    user_output_nodes = output_nodes[end:end + len(gs.user_outputs)]
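
    # Output layout implied by the slicing above:
    #   [*output_tokens, *mutations (buffers and user inputs), *user_outputs, ...]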

    for mutation_node in mutate_nodes:
        if mutation_node in gs.buffers_to_mutate:
            if gs.buffers_to_mutate[mutation_node] not in gs.buffers:
                raise SpecViolationError(
                    f"Buffer output {mutation_node} does not point to a buffer that exists. \n"
                    f"Dict of buffers that are mutated, in order: {gs.buffers_to_mutate} \n"
                    f"Buffer nodes available: {gs.buffers} \n"
                )
        elif mutation_node in gs.user_inputs_to_mutate:
            if gs.user_inputs_to_mutate[mutation_node] not in gs.user_inputs:
                raise SpecViolationError(
                    f"Mutated user input {mutation_node} does not point to a user input that exists. \n"
                    f"Dict of user inputs that are mutated, in order: {gs.user_inputs_to_mutate} \n"
                    f"User input nodes available: {gs.user_inputs} \n"
                )
        else:
            raise SpecViolationError(
                f"Mutation node {mutation_node} is neither a buffer nor a user input. "
                f"Buffers to mutate: {gs.buffers_to_mutate}, User inputs to mutate: {gs.user_inputs_to_mutate}"
            )

    for user_output_node, user_output_name in zip(user_output_nodes, gs.user_outputs):
        if user_output_node != user_output_name:
            raise SpecViolationError(
                f"User output {user_output_node} is not in the correct order or "
                f"is not found in the exported program's user_outputs list: {gs.user_outputs}."
            )


def load_verifier(dialect: str) -> Type[Verifier]:
    if dialect == "ATEN" or dialect == "":
        return _VerifierMeta._registry.get(dialect, Verifier)
    return _VerifierMeta._registry[dialect]
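
# Illustrative driver (a sketch): pick a verifier class by dialect and run the
# full set of checks on an ExportedProgram `ep`.
#
#   load_verifier("ATEN")().check(ep)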