1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7# pyre-strict 8 9import copy 10from typing import Dict, Union 11 12import torch 13 14from torch._guards import detect_fake_mode 15from torch.export import ExportedProgram 16 17 18def get_fake_program(real_exported_program: ExportedProgram) -> ExportedProgram: 19 """Create a fake exported program. This uses fake tensors for the state dict 20 to prevent mutation, and points to the real constants, to avoid large memory 21 usage from copying when constants are large. 22 23 Args: 24 real_exported_program: the original exported program 25 Returns: 26 A new exported program, with fake tensors. 27 """ 28 fake_mode = detect_fake_mode( 29 tuple( 30 node.meta["val"] 31 for node in real_exported_program.graph.nodes 32 if node.op == "placeholder" 33 ) 34 ) 35 if fake_mode is None: 36 raise AssertionError( 37 "Could not detect fake mode for graph: ", real_exported_program.graph 38 ) 39 40 new_state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]] = {} 41 42 for key, tensor in real_exported_program.state_dict.items(): 43 fake = fake_mode.from_tensor(tensor, static_shapes=True) 44 new_state_dict[key] = fake 45 46 gm = copy.deepcopy(real_exported_program.graph_module) 47 fake_exported_program = ExportedProgram( 48 root=gm, 49 graph=gm.graph, 50 graph_signature=copy.deepcopy(real_exported_program.graph_signature), 51 state_dict=new_state_dict, 52 range_constraints=copy.deepcopy(real_exported_program.range_constraints), 53 module_call_graph=copy.deepcopy(real_exported_program.module_call_graph), 54 constants=real_exported_program.constants, 55 verifiers=[real_exported_program.verifier], 56 ) 57 return fake_exported_program 58 59 60def update_to_real_program( 61 fake_exported_program: ExportedProgram, real_exported_program: ExportedProgram 62) -> None: 63 """Update the fake exported program to point to the real state dict. Modifies the 64 fake exported program in-place. 65 """ 66 for k, v in real_exported_program.state_dict.items(): 67 fake_exported_program._state_dict[k] = v 68