# Owner(s): ["module: unknown"]


import logging

import torch
from torch import nn
from torch.ao.pruning.sparsifier import utils
from torch.nn.utils import parametrize
from torch.testing._internal.common_utils import TestCase


logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)


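# Fixture model: a standalone Linear plus a two-layer Sequential, with all
# parameters set to known constants so the tests below can compare exact values.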
class ModelUnderTest(nn.Module):
    def __init__(self, bias=True):
        super().__init__()
        self.linear = nn.Linear(16, 16, bias=bias)
        self.seq = nn.Sequential(
            nn.Linear(16, 16, bias=bias), nn.Linear(16, 16, bias=bias)
        )

        # Make sure the weights are not random
        self.linear.weight = nn.Parameter(torch.zeros_like(self.linear.weight) + 1.0)
        self.seq[0].weight = nn.Parameter(torch.zeros_like(self.seq[0].weight) + 2.0)
        self.seq[1].weight = nn.Parameter(torch.zeros_like(self.seq[1].weight) + 3.0)
        if bias:
            self.linear.bias = nn.Parameter(torch.zeros_like(self.linear.bias) + 10.0)
            self.seq[0].bias = nn.Parameter(torch.zeros_like(self.seq[0].bias) + 20.0)
            self.seq[1].bias = nn.Parameter(torch.zeros_like(self.seq[1].bias) + 30.0)

    def forward(self, x):
        x = self.linear(x)
        x = self.seq(x)
        return x


class TestFakeSparsity(TestCase):
    def test_masking_logic(self):
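        # FakeSparsity multiplies the weight elementwise by its mask, so
        # parametrizing with an all-zeros mask should zero the layer's output.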
        model = nn.Linear(16, 16, bias=False)
        model.weight = nn.Parameter(torch.eye(16))
        x = torch.randn(3, 16)
        self.assertEqual(torch.mm(x, torch.eye(16)), model(x))

        mask = torch.zeros(16, 16)
        sparsity = utils.FakeSparsity(mask)
        parametrize.register_parametrization(model, "weight", sparsity)

        x = torch.randn(3, 16)
        self.assertEqual(torch.zeros(3, 16), model(x))

    def test_weights_parametrized(self):
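        # Registering FakeSparsity should attach a `parametrizations` container
        # to each layer and report its "weight" tensor as parametrized.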
        model = ModelUnderTest(bias=False)

        assert not hasattr(model.linear, "parametrizations")
        assert not hasattr(model.seq[0], "parametrizations")
        assert not hasattr(model.seq[1], "parametrizations")
        mask = torch.eye(16)
        parametrize.register_parametrization(
            model.linear, "weight", utils.FakeSparsity(mask)
        )
        mask = torch.eye(16)
        parametrize.register_parametrization(
            model.seq[0], "weight", utils.FakeSparsity(mask)
        )
        mask = torch.eye(16)
        parametrize.register_parametrization(
            model.seq[1], "weight", utils.FakeSparsity(mask)
        )

        assert hasattr(model.linear, "parametrizations")
        assert parametrize.is_parametrized(model.linear, "weight")
        assert hasattr(model.seq[0], "parametrizations")
        assert parametrize.is_parametrized(model.seq[0], "weight")
        assert hasattr(model.seq[1], "parametrizations")
        assert parametrize.is_parametrized(model.seq[1], "weight")

    def test_state_dict_preserved(self):
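        # Round-trip a parametrized model through state_dict()/load_state_dict()
        # and verify the original weights survive while the masks do not.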
        model_save = ModelUnderTest(bias=False)

        mask = torch.eye(16)
        parametrize.register_parametrization(
            model_save.linear, "weight", utils.FakeSparsity(mask)
        )
        mask = torch.eye(16)
        parametrize.register_parametrization(
            model_save.seq[0], "weight", utils.FakeSparsity(mask)
        )
        mask = torch.eye(16)
        parametrize.register_parametrization(
            model_save.seq[1], "weight", utils.FakeSparsity(mask)
        )
        state_dict = model_save.state_dict()

        model_load = ModelUnderTest(bias=False)
        mask = torch.zeros(model_load.linear.weight.shape)
        parametrize.register_parametrization(
            model_load.linear, "weight", utils.FakeSparsity(mask)
        )
        mask = torch.zeros(model_load.seq[0].weight.shape)
        parametrize.register_parametrization(
            model_load.seq[0], "weight", utils.FakeSparsity(mask)
        )
        mask = torch.zeros(model_load.seq[1].weight.shape)
        parametrize.register_parametrization(
            model_load.seq[1], "weight", utils.FakeSparsity(mask)
        )
        # Use strict=False: the 'mask' buffers are deliberately excluded from
        # the state_dict, so they would otherwise be reported as missing keys.
        model_load.load_state_dict(state_dict, strict=False)

        # Check the parametrizations are preserved
        assert hasattr(model_load.linear, "parametrizations")
        assert parametrize.is_parametrized(model_load.linear, "weight")
        assert hasattr(model_load.seq[0], "parametrizations")
        assert parametrize.is_parametrized(model_load.seq[0], "weight")
        assert hasattr(model_load.seq[1], "parametrizations")
        assert parametrize.is_parametrized(model_load.seq[1], "weight")

        # Check the weights are preserved
        self.assertEqual(
            model_save.linear.parametrizations["weight"].original,
            model_load.linear.parametrizations["weight"].original,
        )
        self.assertEqual(
            model_save.seq[0].parametrizations["weight"].original,
            model_load.seq[0].parametrizations["weight"].original,
        )
        self.assertEqual(
            model_save.seq[1].parametrizations["weight"].original,
            model_load.seq[1].parametrizations["weight"].original,
        )

        # Check the masks are not preserved in the state_dict
        # We store the state_dicts in the sparsifier, not in the model itself.
        # TODO: Need to find a clean way of exporting the parametrized model
        self.assertNotEqual(
            model_save.linear.parametrizations["weight"][0].mask,
            model_load.linear.parametrizations["weight"][0].mask,
        )
        self.assertNotEqual(
            model_save.seq[0].parametrizations["weight"][0].mask,
            model_load.seq[0].parametrizations["weight"][0].mask,
        )
        self.assertNotEqual(
            model_save.seq[1].parametrizations["weight"][0].mask,
            model_load.seq[1].parametrizations["weight"][0].mask,
        )

    def test_jit_trace(self):
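        # Tracing should bake the parametrized (masked) weights into the traced
        # graph, so the traced module must match the eager module on new inputs.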
        model = ModelUnderTest(bias=False)

        mask = torch.eye(16)
        parametrize.register_parametrization(
            model.linear, "weight", utils.FakeSparsity(mask)
        )
        mask = torch.eye(16)
        parametrize.register_parametrization(
            model.seq[0], "weight", utils.FakeSparsity(mask)
        )
        mask = torch.eye(16)
        parametrize.register_parametrization(
            model.seq[1], "weight", utils.FakeSparsity(mask)
        )

        # Tracing
        example_x = torch.ones(3, 16)
        model_trace = torch.jit.trace_module(model, {"forward": example_x})

        x = torch.randn(3, 16)
        y = model(x)
        y_hat = model_trace(x)
        self.assertEqual(y_hat, y)

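
# NOTE: entry point added here as an assumption; upstream may instead collect
# this file through a suite-level runner rather than executing it directly.
if __name__ == "__main__":
    from torch.testing._internal.common_utils import run_tests

    run_tests()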
176