1# Owner(s): ["module: unknown"] 2 3from typing import Dict, Any, Tuple 4from torch.ao.pruning import BaseSparsifier 5import torch 6import torch.nn.functional as F 7from torch import nn 8 9class ImplementedSparsifier(BaseSparsifier): 10 def __init__(self, **kwargs: Dict[str, Any]) -> None: 11 super().__init__(defaults=kwargs) 12 13 def update_mask(self, module: nn.Module, tensor_name: str, **kwargs: Dict[str, Any]) -> None: 14 module.parametrizations.weight[0].mask[0] = 0 15 linear_state = self.state['linear1.weight'] 16 linear_state['step_count'] = linear_state.get('step_count', 0) + 1 17 18 19class MockSparseLinear(nn.Linear): 20 """ 21 This class is a MockSparseLinear class to check convert functionality. 22 It is the same as a normal Linear layer, except with a different type, as 23 well as an additional from_dense method. 24 """ 25 @classmethod 26 def from_dense(cls, mod: nn.Linear) -> 'MockSparseLinear': 27 """ 28 """ 29 linear = cls(mod.in_features, 30 mod.out_features) 31 return linear 32 33 34def rows_are_subset(subset_tensor: torch.Tensor, superset_tensor: torch.Tensor) -> bool: 35 """ 36 Checks to see if all rows in subset tensor are present in the superset tensor 37 """ 38 i = 0 39 for row in subset_tensor: 40 while i < len(superset_tensor): 41 if not torch.equal(row, superset_tensor[i]): 42 i += 1 43 else: 44 break 45 else: 46 return False 47 return True 48 49 50class SimpleLinear(nn.Module): 51 r"""Model with only Linear layers without biases, some wrapped in a Sequential, 52 some following the Sequential. Used to test basic pruned Linear-Linear fusion.""" 53 54 def __init__(self) -> None: 55 super().__init__() 56 self.seq = nn.Sequential( 57 nn.Linear(7, 5, bias=False), 58 nn.Linear(5, 6, bias=False), 59 nn.Linear(6, 4, bias=False), 60 ) 61 self.linear1 = nn.Linear(4, 4, bias=False) 62 self.linear2 = nn.Linear(4, 10, bias=False) 63 64 def forward(self, x: torch.Tensor) -> torch.Tensor: 65 x = self.seq(x) 66 x = self.linear1(x) 67 x = self.linear2(x) 68 return x 69 70 71class LinearBias(nn.Module): 72 r"""Model with only Linear layers, alternating layers with biases, 73 wrapped in a Sequential. Used to test pruned Linear-Bias-Linear fusion.""" 74 75 def __init__(self) -> None: 76 super().__init__() 77 self.seq = nn.Sequential( 78 nn.Linear(7, 5, bias=True), 79 nn.Linear(5, 6, bias=False), 80 nn.Linear(6, 3, bias=True), 81 nn.Linear(3, 3, bias=True), 82 nn.Linear(3, 10, bias=False), 83 ) 84 85 def forward(self, x: torch.Tensor) -> torch.Tensor: 86 x = self.seq(x) 87 return x 88 89 90class LinearActivation(nn.Module): 91 r"""Model with only Linear layers, some with bias, some in a Sequential and some following. 92 Activation functions modules in between each Linear in the Sequential, and each outside layer. 93 Used to test pruned Linear(Bias)-Activation-Linear fusion.""" 94 95 def __init__(self) -> None: 96 super().__init__() 97 self.seq = nn.Sequential( 98 nn.Linear(7, 5, bias=True), 99 nn.ReLU(), 100 nn.Linear(5, 6, bias=False), 101 nn.Tanh(), 102 nn.Linear(6, 4, bias=True), 103 ) 104 self.linear1 = nn.Linear(4, 3, bias=True) 105 self.act1 = nn.ReLU() 106 self.linear2 = nn.Linear(3, 10, bias=False) 107 self.act2 = nn.Tanh() 108 109 def forward(self, x: torch.Tensor) -> torch.Tensor: 110 x = self.seq(x) 111 x = self.linear1(x) 112 x = self.act1(x) 113 x = self.linear2(x) 114 x = self.act2(x) 115 return x 116 117 118class LinearActivationFunctional(nn.Module): 119 r"""Model with only Linear layers, some with bias, some in a Sequential and some following. 
    Activation function modules in between each Linear in the Sequential, and functional
    activations called in between each outside layer.
    Used to test pruned Linear(Bias)-Activation-Linear fusion."""

    def __init__(self) -> None:
        super().__init__()
        self.seq = nn.Sequential(
            nn.Linear(7, 5, bias=True),
            nn.ReLU(),
            nn.Linear(5, 6, bias=False),
            nn.ReLU(),
            nn.Linear(6, 4, bias=True),
        )
        self.linear1 = nn.Linear(4, 3, bias=True)
        self.linear2 = nn.Linear(3, 8, bias=False)
        self.linear3 = nn.Linear(8, 10, bias=False)
        self.act1 = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.seq(x)
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        x = F.relu(x)
        x = self.linear3(x)
        x = F.relu(x)
        return x


class SimpleConv2d(nn.Module):
    r"""Model with only Conv2d layers, all without bias, some in a Sequential and some following.
    Used to test pruned Conv2d-Conv2d fusion."""

    def __init__(self) -> None:
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, bias=False),
            nn.Conv2d(32, 64, 3, 1, bias=False),
        )
        self.conv2d1 = nn.Conv2d(64, 48, 3, 1, bias=False)
        self.conv2d2 = nn.Conv2d(48, 52, 3, 1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.seq(x)
        x = self.conv2d1(x)
        x = self.conv2d2(x)
        return x


class Conv2dBias(nn.Module):
    r"""Model with only Conv2d layers, some with bias, some in a Sequential and some outside.
    Used to test pruned Conv2d-Bias-Conv2d fusion."""

    def __init__(self) -> None:
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, bias=True),
            nn.Conv2d(32, 32, 3, 1, bias=True),
            nn.Conv2d(32, 64, 3, 1, bias=False),
        )
        self.conv2d1 = nn.Conv2d(64, 48, 3, 1, bias=True)
        self.conv2d2 = nn.Conv2d(48, 52, 3, 1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.seq(x)
        x = self.conv2d1(x)
        x = self.conv2d2(x)
        return x


class Conv2dActivation(nn.Module):
    r"""Model with only Conv2d layers, some with bias, some in a Sequential and some following.
    Activation function modules in between each Sequential layer, functional activations called
    in between each outside layer.
    Used to test pruned Conv2d-Bias-Activation-Conv2d fusion."""

    def __init__(self) -> None:
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, bias=True),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1, bias=True),
            nn.Tanh(),
            nn.Conv2d(64, 64, 3, 1, bias=False),
            nn.ReLU(),
        )
        self.conv2d1 = nn.Conv2d(64, 48, 3, 1, bias=False)
        self.conv2d2 = nn.Conv2d(48, 52, 3, 1, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.seq(x)
        x = self.conv2d1(x)
        x = F.relu(x)
        x = self.conv2d2(x)
        x = F.hardtanh(x)
        return x


class Conv2dPadBias(nn.Module):
    r"""Model with only Conv2d layers, all with bias and some with padding > 0,
    some in a Sequential and some following. Activation function modules in between each layer.
    Used to test that bias is propagated correctly in the special case of
    pruned Conv2d-Bias-(Activation)Conv2d fusion, when the second Conv2d layer has padding > 0."""

    def __init__(self) -> None:
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, padding=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, 1, padding=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, 1, padding=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1, bias=True),
            nn.Tanh(),
        )
        self.conv2d1 = nn.Conv2d(64, 48, 3, 1, padding=1, bias=True)
        self.act1 = nn.ReLU()
        self.conv2d2 = nn.Conv2d(48, 52, 3, 1, padding=1, bias=True)
        self.act2 = nn.Tanh()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.seq(x)
        x = self.conv2d1(x)
        x = self.act1(x)
        x = self.conv2d2(x)
        x = self.act2(x)
        return x


class Conv2dPool(nn.Module):
    r"""Model with only Conv2d layers, all with bias, some in a Sequential and some following.
    Activation function modules and Pool2d modules in between each layer.
    Used to test pruned Conv2d-Pool2d-Conv2d fusion."""

    def __init__(self) -> None:
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=True),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=1),
        )
        self.conv2d1 = nn.Conv2d(64, 48, kernel_size=3, padding=1, bias=True)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=1)
        self.af1 = nn.ReLU()
        self.conv2d2 = nn.Conv2d(48, 52, kernel_size=3, padding=1, bias=True)
        self.conv2d3 = nn.Conv2d(52, 52, kernel_size=3, padding=1, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.seq(x)
        x = self.conv2d1(x)
        x = self.maxpool(x)
        x = self.af1(x)
        x = self.conv2d2(x)
        x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=1)
        x = F.relu(x)
        x = self.conv2d3(x)
        return x


class Conv2dPoolFlattenFunctional(nn.Module):
    r"""Model with Conv2d layers, all with bias, some in a Sequential and some following, and then a Pool2d
    and a functional Flatten followed by a Linear layer.
    Activation functions and Pool2ds in between each layer also.
    Used to test pruned Conv2d-Pool2d-Flatten-Linear fusion."""

    def __init__(self) -> None:
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 3, kernel_size=3, padding=1, bias=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(3, 5, kernel_size=3, padding=1, bias=True),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=1),
        )
        self.conv2d1 = nn.Conv2d(5, 7, kernel_size=3, padding=1, bias=True)
        self.af1 = nn.ReLU()
        self.conv2d2 = nn.Conv2d(7, 11, kernel_size=3, padding=1, bias=True)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(11, 13, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.seq(x)
        x = self.conv2d1(x)
        x = F.max_pool2d(x, kernel_size=2, stride=2, padding=1)
        x = self.af1(x)
        x = self.conv2d2(x)
        x = self.avg_pool(x)
        x = torch.flatten(x, 1)  # test functional flatten
        x = self.fc(x)
        return x


class Conv2dPoolFlatten(nn.Module):
    r"""Model with Conv2d layers, all with bias, some in a Sequential and some following, and then a Pool2d
    and a Flatten module followed by a Linear layer.
    Activation functions and Pool2ds in between each layer also.
    Used to test pruned Conv2d-Pool2d-Flatten-Linear fusion."""

    def __init__(self) -> None:
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 3, kernel_size=3, padding=1, bias=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(3, 5, kernel_size=3, padding=1, bias=True),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=1),
        )
        self.conv2d1 = nn.Conv2d(5, 7, kernel_size=3, padding=1, bias=True)
        self.af1 = nn.ReLU()
        self.conv2d2 = nn.Conv2d(7, 11, kernel_size=3, padding=1, bias=True)
        self.avg_pool = nn.AdaptiveAvgPool2d((2, 2))
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(44, 13, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.seq(x)
        x = self.conv2d1(x)
        x = F.max_pool2d(x, kernel_size=2, stride=2, padding=1)
        x = self.af1(x)
        x = self.conv2d2(x)
        x = self.avg_pool(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x


class LSTMLinearModel(nn.Module):
    """Container module with a recurrent LSTM module followed by a Linear decoder."""

    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int
    ) -> None:
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers)
        self.linear = nn.Linear(hidden_dim, output_dim)

    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        output, hidden = self.lstm(input)
        decoded = self.linear(output)
        return decoded, output


class LSTMLayerNormLinearModel(nn.Module):
    """Container module with an LSTM, a LayerNorm, and a Linear layer."""

    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int
    ) -> None:
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers)
        self.norm = nn.LayerNorm(hidden_dim)
        self.linear = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        x, state = self.lstm(x)
        x = self.norm(x)
        x = self.linear(x)
        return x, state
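

# ---------------------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the utilities above). It assumes
# the torch.ao.pruning BaseSparsifier flow of prepare()/step() with a 'tensor_fqn'-style
# config; exact config keys may vary between PyTorch versions, and `test_arg` is just an
# arbitrary example default passed through to update_mask.
if __name__ == "__main__":
    # rows_are_subset: rows of the first tensor appear, in order, among rows of the second.
    small = torch.tensor([[1.0, 2.0], [5.0, 6.0]])
    large = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    assert rows_are_subset(small, large)
    assert not rows_are_subset(large, small)

    # SimpleLinear maps 7 input features to 10 output features.
    model = SimpleLinear()
    assert model(torch.randn(2, 7)).shape == (2, 10)

    # Attach the test sparsifier to linear1.weight and run one mask update.
    sparsifier = ImplementedSparsifier(test_arg=1)
    sparsifier.prepare(model, config=[{"tensor_fqn": "linear1.weight"}])
    sparsifier.step()
    assert sparsifier.state["linear1.weight"]["step_count"] == 1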