# Owner(s): ["module: unknown"]

from typing import Dict, Any, Tuple
from torch.ao.pruning import BaseSparsifier
import torch
import torch.nn.functional as F
from torch import nn

class ImplementedSparsifier(BaseSparsifier):
    def __init__(self, **kwargs: Dict[str, Any]) -> None:
        super().__init__(defaults=kwargs)

    def update_mask(self, module: nn.Module, tensor_name: str, **kwargs: Dict[str, Any]) -> None:
        # Zero out the first row of the mask and track how many steps have run.
        module.parametrizations.weight[0].mask[0] = 0
        linear_state = self.state['linear1.weight']
        linear_state['step_count'] = linear_state.get('step_count', 0) + 1
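

# A minimal usage sketch (illustrative only, not exercised by the tests here):
# BaseSparsifier provides `prepare`, which attaches the mask parametrization to
# the configured tensor, and `step`, which calls `update_mask` above. The
# 'tensor_fqn' config key comes from BaseSparsifier; the model choice, the
# kwarg `test=3`, and this helper itself are assumptions made for illustration.
def _example_sparsifier_usage() -> None:
    model = SimpleLinear()  # defined later in this module
    sparsifier = ImplementedSparsifier(test=3)
    sparsifier.prepare(model, [{'tensor_fqn': 'linear1.weight'}])
    sparsifier.step()  # zeroes the first mask row and bumps 'step_count'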


class MockSparseLinear(nn.Linear):
    """
    MockSparseLinear is used to check convert functionality. It is the same as
    a normal Linear layer, except with a different type, as well as an
    additional from_dense method.
    """
    @classmethod
    def from_dense(cls, mod: nn.Linear) -> 'MockSparseLinear':
        """
        Create a MockSparseLinear with the same in/out features as the given
        dense nn.Linear module. The dense weights are not copied.
        """
        linear = cls(mod.in_features,
                     mod.out_features)
        return linear


def rows_are_subset(subset_tensor: torch.Tensor, superset_tensor: torch.Tensor) -> bool:
    """
    Checks to see if all rows in the subset tensor are present in the superset
    tensor. Rows are matched in order: each subset row is searched for starting
    at the position where the previous row was found.
    """
    i = 0
    for row in subset_tensor:
        # Advance through the superset until a matching row is found.
        while i < len(superset_tensor):
            if not torch.equal(row, superset_tensor[i]):
                i += 1
            else:
                break
        else:
            # The superset was exhausted without finding this row.
            return False
    return True
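

# Illustrative check of the matching behavior (an added sketch, not used by the
# tests): because rows are matched in order, a reordered subset is not found.
def _example_rows_are_subset() -> None:
    superset = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])
    assert rows_are_subset(torch.tensor([[1., 2.], [5., 6.]]), superset)
    assert not rows_are_subset(torch.tensor([[5., 6.], [1., 2.]]), superset)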


class SimpleLinear(nn.Module):
    r"""Model with only Linear layers without biases, some wrapped in a Sequential,
    some following the Sequential. Used to test basic pruned Linear-Linear fusion."""

    def __init__(self) -> None:
        super().__init__()
        self.seq = nn.Sequential(
            nn.Linear(7, 5, bias=False),
            nn.Linear(5, 6, bias=False),
            nn.Linear(6, 4, bias=False),
        )
        self.linear1 = nn.Linear(4, 4, bias=False)
        self.linear2 = nn.Linear(4, 10, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.seq(x)
        x = self.linear1(x)
        x = self.linear2(x)
        return x
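

# Shape sketch (the input size below is an arbitrary choice for illustration):
# SimpleLinear maps a trailing dimension of 7 to 10 (7 -> 5 -> 6 -> 4 -> 4 -> 10).
def _example_simple_linear_shapes() -> None:
    out = SimpleLinear()(torch.randn(2, 7))
    assert out.shape == (2, 10)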


class LinearBias(nn.Module):
    r"""Model with only Linear layers, alternating layers with biases,
    wrapped in a Sequential. Used to test pruned Linear-Bias-Linear fusion."""

    def __init__(self) -> None:
        super().__init__()
        self.seq = nn.Sequential(
            nn.Linear(7, 5, bias=True),
            nn.Linear(5, 6, bias=False),
            nn.Linear(6, 3, bias=True),
            nn.Linear(3, 3, bias=True),
            nn.Linear(3, 10, bias=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.seq(x)
        return x


class LinearActivation(nn.Module):
    r"""Model with only Linear layers, some with bias, some in a Sequential and some following.
    Activation function modules in between each Linear in the Sequential, and after each outside layer.
    Used to test pruned Linear(Bias)-Activation-Linear fusion."""

    def __init__(self) -> None:
        super().__init__()
        self.seq = nn.Sequential(
            nn.Linear(7, 5, bias=True),
            nn.ReLU(),
            nn.Linear(5, 6, bias=False),
            nn.Tanh(),
            nn.Linear(6, 4, bias=True),
        )
        self.linear1 = nn.Linear(4, 3, bias=True)
        self.act1 = nn.ReLU()
        self.linear2 = nn.Linear(3, 10, bias=False)
        self.act2 = nn.Tanh()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.seq(x)
        x = self.linear1(x)
        x = self.act1(x)
        x = self.linear2(x)
        x = self.act2(x)
        return x


class LinearActivationFunctional(nn.Module):
    r"""Model with only Linear layers, some with bias, some in a Sequential and some following.
    Activation function modules in between each Linear in the Sequential, and functional
    activations called after each outside layer.
    Used to test pruned Linear(Bias)-Activation-Linear fusion."""

    def __init__(self) -> None:
        super().__init__()
        self.seq = nn.Sequential(
            nn.Linear(7, 5, bias=True),
            nn.ReLU(),
            nn.Linear(5, 6, bias=False),
            nn.ReLU(),
            nn.Linear(6, 4, bias=True),
        )
        self.linear1 = nn.Linear(4, 3, bias=True)
        self.linear2 = nn.Linear(3, 8, bias=False)
        self.linear3 = nn.Linear(8, 10, bias=False)
        self.act1 = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.seq(x)
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        x = F.relu(x)
        x = self.linear3(x)
        x = F.relu(x)
        return x


class SimpleConv2d(nn.Module):
    r"""Model with only Conv2d layers, all without bias, some in a Sequential and some following.
    Used to test pruned Conv2d-Conv2d fusion."""

    def __init__(self) -> None:
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, bias=False),
            nn.Conv2d(32, 64, 3, 1, bias=False),
        )
        self.conv2d1 = nn.Conv2d(64, 48, 3, 1, bias=False)
        self.conv2d2 = nn.Conv2d(48, 52, 3, 1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.seq(x)
        x = self.conv2d1(x)
        x = self.conv2d2(x)
        return x


class Conv2dBias(nn.Module):
    r"""Model with only Conv2d layers, some with bias, some in a Sequential and some outside.
    Used to test pruned Conv2d-Bias-Conv2d fusion."""

    def __init__(self) -> None:
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, bias=True),
            nn.Conv2d(32, 32, 3, 1, bias=True),
            nn.Conv2d(32, 64, 3, 1, bias=False),
        )
        self.conv2d1 = nn.Conv2d(64, 48, 3, 1, bias=True)
        self.conv2d2 = nn.Conv2d(48, 52, 3, 1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.seq(x)
        x = self.conv2d1(x)
        x = self.conv2d2(x)
        return x


class Conv2dActivation(nn.Module):
    r"""Model with only Conv2d layers, some with bias, some in a Sequential and some following.
    Activation function modules in between each Sequential layer, functional activations called
    in-between each outside layer.
    Used to test pruned Conv2d-Bias-Activation-Conv2d fusion."""

    def __init__(self) -> None:
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, bias=True),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1, bias=True),
            nn.Tanh(),
            nn.Conv2d(64, 64, 3, 1, bias=False),
            nn.ReLU(),
        )
        self.conv2d1 = nn.Conv2d(64, 48, 3, 1, bias=False)
        self.conv2d2 = nn.Conv2d(48, 52, 3, 1, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.seq(x)
        x = self.conv2d1(x)
        x = F.relu(x)
        x = self.conv2d2(x)
        x = F.hardtanh(x)
        return x


class Conv2dPadBias(nn.Module):
    r"""Model with only Conv2d layers, most with bias and some with padding > 0,
    some in a Sequential and some following. Activation function modules in between each layer.
    Used to test that bias is propagated correctly in the special case of
    pruned Conv2d-Bias-(Activation)Conv2d fusion, when the second Conv2d layer has padding > 0."""

    def __init__(self) -> None:
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, padding=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, 1, padding=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, 1, padding=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1, bias=True),
            nn.Tanh(),
        )
        self.conv2d1 = nn.Conv2d(64, 48, 3, 1, padding=1, bias=True)
        self.act1 = nn.ReLU()
        self.conv2d2 = nn.Conv2d(48, 52, 3, 1, padding=1, bias=True)
        self.act2 = nn.Tanh()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.seq(x)
        x = self.conv2d1(x)
        x = self.act1(x)
        x = self.conv2d2(x)
        x = self.act2(x)
        return x


class Conv2dPool(nn.Module):
    r"""Model with only Conv2d layers, all with bias, some in a Sequential and some following.
    Activation function modules and Pool2d modules in between each layer.
    Used to test pruned Conv2d-Pool2d-Conv2d fusion."""

    def __init__(self) -> None:
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=True),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=1),
        )
        self.conv2d1 = nn.Conv2d(64, 48, kernel_size=3, padding=1, bias=True)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=1)
        self.af1 = nn.ReLU()
        self.conv2d2 = nn.Conv2d(48, 52, kernel_size=3, padding=1, bias=True)
        self.conv2d3 = nn.Conv2d(52, 52, kernel_size=3, padding=1, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.seq(x)
        x = self.conv2d1(x)
        x = self.maxpool(x)
        x = self.af1(x)
        x = self.conv2d2(x)
        x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=1)
        x = F.relu(x)
        x = self.conv2d3(x)
        return x


class Conv2dPoolFlattenFunctional(nn.Module):
    r"""Model with Conv2d layers, all with bias, some in a Sequential and some following, and then a Pool2d
    and a functional Flatten followed by a Linear layer.
    Activation functions and Pool2ds in between each layer also.
    Used to test pruned Conv2d-Pool2d-Flatten-Linear fusion."""

    def __init__(self) -> None:
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 3, kernel_size=3, padding=1, bias=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(3, 5, kernel_size=3, padding=1, bias=True),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=1),
        )
        self.conv2d1 = nn.Conv2d(5, 7, kernel_size=3, padding=1, bias=True)
        self.af1 = nn.ReLU()
        self.conv2d2 = nn.Conv2d(7, 11, kernel_size=3, padding=1, bias=True)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(11, 13, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.seq(x)
        x = self.conv2d1(x)
        x = F.max_pool2d(x, kernel_size=2, stride=2, padding=1)
        x = self.af1(x)
        x = self.conv2d2(x)
        x = self.avg_pool(x)
        x = torch.flatten(x, 1)  # test functional flatten
        x = self.fc(x)
        return x
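

# Shape sketch (the 8x8 input is an arbitrary choice for illustration): the
# AdaptiveAvgPool2d((1, 1)) collapses the spatial dims, so the functional
# flatten always produces 11 features, matching the Linear(11, 13) head.
def _example_conv2d_pool_flatten_functional_shapes() -> None:
    out = Conv2dPoolFlattenFunctional()(torch.randn(1, 1, 8, 8))
    assert out.shape == (1, 13)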


class Conv2dPoolFlatten(nn.Module):
    r"""Model with Conv2d layers, all with bias, some in a Sequential and some following, and then a Pool2d
    and a Flatten module followed by a Linear layer.
    Activation functions and Pool2ds in between each layer also.
    Used to test pruned Conv2d-Pool2d-Flatten-Linear fusion."""

    def __init__(self) -> None:
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 3, kernel_size=3, padding=1, bias=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(3, 5, kernel_size=3, padding=1, bias=True),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=1),
        )
        self.conv2d1 = nn.Conv2d(5, 7, kernel_size=3, padding=1, bias=True)
        self.af1 = nn.ReLU()
        self.conv2d2 = nn.Conv2d(7, 11, kernel_size=3, padding=1, bias=True)
        self.avg_pool = nn.AdaptiveAvgPool2d((2, 2))
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(44, 13, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.seq(x)
        x = self.conv2d1(x)
        x = F.max_pool2d(x, kernel_size=2, stride=2, padding=1)
        x = self.af1(x)
        x = self.conv2d2(x)
        x = self.avg_pool(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x


class LSTMLinearModel(nn.Module):
    """Container module with an LSTM and a Linear layer that decodes its output."""

    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int
    ) -> None:
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers)
        self.linear = nn.Linear(hidden_dim, output_dim)

    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        output, hidden = self.lstm(input)
        decoded = self.linear(output)
        return decoded, output
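

# Usage sketch (the dimensions below are arbitrary choices for illustration):
# nn.LSTM without batch_first expects (seq_len, batch, input_dim); the model
# returns the decoded sequence and the raw LSTM output.
def _example_lstm_linear_shapes() -> None:
    model = LSTMLinearModel(input_dim=8, hidden_dim=16, output_dim=10, num_layers=1)
    decoded, lstm_out = model(torch.randn(5, 2, 8))
    assert decoded.shape == (5, 2, 10)
    assert lstm_out.shape == (5, 2, 16)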


class LSTMLayerNormLinearModel(nn.Module):
    """Container module with an LSTM, a LayerNorm, and a linear."""

    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int
    ) -> None:
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers)
        self.norm = nn.LayerNorm(hidden_dim)
        self.linear = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        x, state = self.lstm(x)
        x = self.norm(x)
        x = self.linear(x)
        return x, state