# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import sys
import unittest

import torch

from executorch.exir.program._fake_program import (
    get_fake_program,
    update_to_real_program,
)
from torch.export import export, ExportedProgram


def get_exported_program() -> ExportedProgram:
    """Export a small linear module and return the decomposed program.

    The module holds a 10x10 ``torch.nn.Linear`` layer plus a non-persistent
    10x10 buffer, and is exported with a random 10x10 example input.
    """

    class Linear(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(10, 10)
            self.register_buffer("buf", torch.randn(10, 10), persistent=False)

        def forward(self, arg: torch.Tensor) -> torch.Tensor:
            return self.linear(arg) + self.buf

    linear = Linear()
    exported_program = export(
        linear,
        args=(torch.randn(10, 10),),
    ).run_decompositions()
    return exported_program


class TestFakeProgram(unittest.TestCase):
    """Tests for ``get_fake_program`` and ``update_to_real_program``."""

    def test_fake_program(self) -> None:
        """Check what a fake program copies, what it shares, and that it can be restored."""
        exported_program = get_exported_program()
        fake_program = get_fake_program(exported_program)
        print(f"Exported program size: {sys.getsizeof(exported_program.state_dict)}")
        print(f"Fake program size: {sys.getsizeof(fake_program.state_dict)}")

        # Fake program deep copies attributes besides verifier, state_dict and constants.
        self.assertEqual(exported_program.graph_signature, fake_program.graph_signature)
        self.assertIsNot(exported_program.graph_signature, fake_program.graph_signature)
        self.assertEqual(
            exported_program.module_call_graph, fake_program.module_call_graph
        )
        self.assertIsNot(
            exported_program.module_call_graph, fake_program.module_call_graph
        )

        # Verifier is static.
        self.assertEqual(exported_program.verifier, fake_program.verifier)
        self.assertIs(exported_program.verifier, fake_program.verifier)

        # Fake program uses fake tensors for the state dict. Size should not be larger.
        # NOTE: sys.getsizeof is shallow, so this compares only the dict
        # containers themselves, not the storage of the tensors they hold.
        self.assertLessEqual(
            sys.getsizeof(fake_program.state_dict),
            sys.getsizeof(exported_program.state_dict),
        )

        # Do not copy constants.
        self.assertEqual(exported_program.constants, fake_program.constants)
        self.assertIs(exported_program.constants, fake_program.constants)

        update_to_real_program(fake_program, exported_program)
        # NOTE(review): comparing dicts of tensors with == presumably works here
        # because both dicts hold the very same tensor objects after the
        # update -- confirm against update_to_real_program.
        self.assertEqual(exported_program.state_dict, fake_program.state_dict)
        self.assertEqual(
            exported_program.state_dict.keys(), fake_program.state_dict.keys()
        )