xref: /aosp_15_r20/external/executorch/exir/program/_fake_program.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8
9import copy
10from typing import Dict, Union
11
12import torch
13
14from torch._guards import detect_fake_mode
15from torch.export import ExportedProgram
16
17
18def get_fake_program(real_exported_program: ExportedProgram) -> ExportedProgram:
19    """Create a fake exported program. This uses fake tensors for the state dict
20    to prevent mutation, and points to the real constants, to avoid large memory
21    usage from copying when constants are large.
22
23    Args:
24        real_exported_program: the original exported program
25    Returns:
26        A new exported program, with fake tensors.
27    """
28    fake_mode = detect_fake_mode(
29        tuple(
30            node.meta["val"]
31            for node in real_exported_program.graph.nodes
32            if node.op == "placeholder"
33        )
34    )
35    if fake_mode is None:
36        raise AssertionError(
37            "Could not detect fake mode for graph: ", real_exported_program.graph
38        )
39
40    new_state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]] = {}
41
42    for key, tensor in real_exported_program.state_dict.items():
43        fake = fake_mode.from_tensor(tensor, static_shapes=True)
44        new_state_dict[key] = fake
45
46    gm = copy.deepcopy(real_exported_program.graph_module)
47    fake_exported_program = ExportedProgram(
48        root=gm,
49        graph=gm.graph,
50        graph_signature=copy.deepcopy(real_exported_program.graph_signature),
51        state_dict=new_state_dict,
52        range_constraints=copy.deepcopy(real_exported_program.range_constraints),
53        module_call_graph=copy.deepcopy(real_exported_program.module_call_graph),
54        constants=real_exported_program.constants,
55        verifiers=[real_exported_program.verifier],
56    )
57    return fake_exported_program
58
59
60def update_to_real_program(
61    fake_exported_program: ExportedProgram, real_exported_program: ExportedProgram
62) -> None:
63    """Update the fake exported program to point to the real state dict. Modifies the
64    fake exported program in-place.
65    """
66    for k, v in real_exported_program.state_dict.items():
67        fake_exported_program._state_dict[k] = v
68