1# Owner(s): ["module: unknown"] 2 3 4import logging 5 6import torch 7from torch import nn 8from torch.ao.pruning.sparsifier import utils 9from torch.nn.utils import parametrize 10from torch.testing._internal.common_utils import TestCase 11 12 13logging.basicConfig( 14 format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO 15) 16 17 18class ModelUnderTest(nn.Module): 19 def __init__(self, bias=True): 20 super().__init__() 21 self.linear = nn.Linear(16, 16, bias=bias) 22 self.seq = nn.Sequential( 23 nn.Linear(16, 16, bias=bias), nn.Linear(16, 16, bias=bias) 24 ) 25 26 # Make sure the weights are not random 27 self.linear.weight = nn.Parameter(torch.zeros_like(self.linear.weight) + 1.0) 28 self.seq[0].weight = nn.Parameter(torch.zeros_like(self.seq[0].weight) + 2.0) 29 self.seq[1].weight = nn.Parameter(torch.zeros_like(self.seq[1].weight) + 3.0) 30 if bias: 31 self.linear = nn.Parameter(torch.zeros_like(self.linear.bias) + 10.0) 32 self.seq[0] = nn.Parameter(torch.zeros_like(self.seq[0].bias) + 20.0) 33 self.seq[0] = nn.Parameter(torch.zeros_like(self.seq[0].bias) + 30.0) 34 35 def forward(self, x): 36 x = self.linear(x) 37 x = self.seq(x) 38 return x 39 40 41class TestFakeSparsity(TestCase): 42 def test_masking_logic(self): 43 model = nn.Linear(16, 16, bias=False) 44 model.weight = nn.Parameter(torch.eye(16)) 45 x = torch.randn(3, 16) 46 self.assertEqual(torch.mm(x, torch.eye(16)), model(x)) 47 48 mask = torch.zeros(16, 16) 49 sparsity = utils.FakeSparsity(mask) 50 parametrize.register_parametrization(model, "weight", sparsity) 51 52 x = torch.randn(3, 16) 53 self.assertEqual(torch.zeros(3, 16), model(x)) 54 55 def test_weights_parametrized(self): 56 model = ModelUnderTest(bias=False) 57 58 assert not hasattr(model.linear, "parametrizations") 59 assert not hasattr(model.seq[0], "parametrizations") 60 assert not hasattr(model.seq[1], "parametrizations") 61 mask = torch.eye(16) 62 parametrize.register_parametrization( 63 model.linear, "weight", utils.FakeSparsity(mask) 64 ) 65 mask = torch.eye(16) 66 parametrize.register_parametrization( 67 model.seq[0], "weight", utils.FakeSparsity(mask) 68 ) 69 mask = torch.eye(16) 70 parametrize.register_parametrization( 71 model.seq[1], "weight", utils.FakeSparsity(mask) 72 ) 73 74 assert hasattr(model.linear, "parametrizations") 75 assert parametrize.is_parametrized(model.linear, "weight") 76 assert hasattr(model.seq[0], "parametrizations") 77 assert parametrize.is_parametrized(model.linear, "weight") 78 assert hasattr(model.seq[1], "parametrizations") 79 assert parametrize.is_parametrized(model.linear, "weight") 80 81 def test_state_dict_preserved(self): 82 model_save = ModelUnderTest(bias=False) 83 84 mask = torch.eye(16) 85 parametrize.register_parametrization( 86 model_save.linear, "weight", utils.FakeSparsity(mask) 87 ) 88 mask = torch.eye(16) 89 parametrize.register_parametrization( 90 model_save.seq[0], "weight", utils.FakeSparsity(mask) 91 ) 92 mask = torch.eye(16) 93 parametrize.register_parametrization( 94 model_save.seq[1], "weight", utils.FakeSparsity(mask) 95 ) 96 state_dict = model_save.state_dict() 97 98 model_load = ModelUnderTest(bias=False) 99 mask = torch.zeros(model_load.linear.weight.shape) 100 parametrize.register_parametrization( 101 model_load.linear, "weight", utils.FakeSparsity(mask) 102 ) 103 mask = torch.zeros(model_load.seq[0].weight.shape) 104 parametrize.register_parametrization( 105 model_load.seq[0], "weight", utils.FakeSparsity(mask) 106 ) 107 mask = torch.zeros(model_load.seq[1].weight.shape) 108 parametrize.register_parametrization( 109 model_load.seq[1], "weight", utils.FakeSparsity(mask) 110 ) 111 # Keep this strict, as we are not loading the 'mask' 112 model_load.load_state_dict(state_dict, strict=False) 113 114 # Check the parametrizations are preserved 115 assert hasattr(model_load.linear, "parametrizations") 116 assert parametrize.is_parametrized(model_load.linear, "weight") 117 assert hasattr(model_load.seq[0], "parametrizations") 118 assert parametrize.is_parametrized(model_load.linear, "weight") 119 assert hasattr(model_load.seq[1], "parametrizations") 120 assert parametrize.is_parametrized(model_load.linear, "weight") 121 122 # Check the weights are preserved 123 self.assertEqual( 124 model_save.linear.parametrizations["weight"].original, 125 model_load.linear.parametrizations["weight"].original, 126 ) 127 self.assertEqual( 128 model_save.seq[0].parametrizations["weight"].original, 129 model_load.seq[0].parametrizations["weight"].original, 130 ) 131 self.assertEqual( 132 model_save.seq[1].parametrizations["weight"].original, 133 model_load.seq[1].parametrizations["weight"].original, 134 ) 135 136 # Check the masks are not preserved in the state_dict 137 # We store the state_dicts in the sparsifier, not in the model itself. 138 # TODO: Need to find a clean way of exporting the parametrized model 139 self.assertNotEqual( 140 model_save.linear.parametrizations["weight"][0].mask, 141 model_load.linear.parametrizations["weight"][0].mask, 142 ) 143 self.assertNotEqual( 144 model_save.seq[0].parametrizations["weight"][0].mask, 145 model_load.seq[0].parametrizations["weight"][0].mask, 146 ) 147 self.assertNotEqual( 148 model_save.seq[1].parametrizations["weight"][0].mask, 149 model_load.seq[1].parametrizations["weight"][0].mask, 150 ) 151 152 def test_jit_trace(self): 153 model = ModelUnderTest(bias=False) 154 155 mask = torch.eye(16) 156 parametrize.register_parametrization( 157 model.linear, "weight", utils.FakeSparsity(mask) 158 ) 159 mask = torch.eye(16) 160 parametrize.register_parametrization( 161 model.seq[0], "weight", utils.FakeSparsity(mask) 162 ) 163 mask = torch.eye(16) 164 parametrize.register_parametrization( 165 model.seq[1], "weight", utils.FakeSparsity(mask) 166 ) 167 168 # Tracing 169 example_x = torch.ones(3, 16) 170 model_trace = torch.jit.trace_module(model, {"forward": example_x}) 171 172 x = torch.randn(3, 16) 173 y = model(x) 174 y_hat = model_trace(x) 175 self.assertEqual(y_hat, y) 176