1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7# pyre-strict 8 9import itertools 10from typing import Any, List, Optional, Tuple, Union 11 12import executorch.exir as exir 13 14import torch # noqa: F401 15import torch.nn as nn 16from executorch.exir import to_edge 17from executorch.exir.lowered_backend_module import LoweredBackendModule 18from torch import Tensor 19from torch.export import export 20 21# TODO: add one more test for data dependent op plus repeat 22 23 24class TensorItem(nn.Module): 25 def __init__(self) -> None: 26 super().__init__() 27 28 def forward(self, arg1: torch.Tensor, arg2: torch.Tensor) -> torch.Tensor: 29 h = arg1.item() 30 w = arg2.item() 31 torch._check(h >= 2) 32 torch._check(h <= 100) 33 torch._check(w >= 2) 34 torch._check(w <= 100) 35 return torch.ones(int(h), int(w)) 36 37 def get_random_inputs(self) -> Tuple[torch.Tensor, torch.Tensor]: 38 return (torch.tensor(10), torch.tensor(20)) 39 40 41class Repeat(nn.Module): 42 def __init__(self) -> None: 43 super().__init__() 44 45 def forward( 46 self, arg1: torch.Tensor, arg2: torch.Tensor 47 ) -> Tuple[torch.Tensor, torch.Tensor]: 48 x = arg2.repeat(arg1.size(0), 1) 49 return x * x, arg2 + arg2 50 51 def get_random_inputs(self) -> Tuple[torch.Tensor, torch.Tensor]: 52 return (torch.rand(4), torch.rand(5)) 53 54 def get_dynamic_shape(self) -> Any: # pyre-ignore[3] 55 dim = torch.export.Dim("dim", max=10) 56 dim2 = torch.export.Dim("dim2", max=10) 57 return ({0: dim}, {0: dim2}) 58 59 60class ModelWithUnusedArg(nn.Module): 61 def __init__(self) -> None: 62 super().__init__() 63 64 def forward(self, arg1: torch.Tensor, arg2: torch.Tensor) -> torch.Tensor: 65 return torch.sin(arg1) 66 67 def get_random_inputs(self) -> Tuple[torch.Tensor, torch.Tensor]: 68 return (torch.rand(4), torch.rand(5)) 69 70 71class MLP(nn.Module): 72 def __init__(self, n_layer: int = 1, output_size: int = 1) -> None: 73 super().__init__() 74 self.n_layer = n_layer 75 self.output_size = output_size 76 # input shape [batch_size, n_layer+output_size] 77 # each linear layer reduce the activation dim 1 size by 1. 78 self.mlp = torch.nn.Sequential( 79 *itertools.chain( 80 *( 81 [nn.Linear(i + output_size, i - 1 + output_size)] 82 + ([nn.ReLU()] if i != 1 else []) 83 for i in range(n_layer, 0, -1) 84 ) 85 ) 86 ) 87 88 def forward(self, inputs: torch.Tensor) -> torch.Tensor: 89 return self.mlp(inputs) 90 91 def get_random_inputs(self) -> Tuple[torch.Tensor, ...]: 92 return (torch.rand(2, self.n_layer + self.output_size),) 93 94 95class Identity(nn.Module): 96 def __init__(self) -> None: 97 super().__init__() 98 99 def forward(self, input: Tensor) -> Tensor: 100 return torch.clone(input) 101 102 103class Reshape(nn.Module): 104 def __init__(self) -> None: 105 super().__init__() 106 107 def forward( 108 self, x: Tensor, *new_shape: Union[torch.Size, Tuple[int, ...], List[int]] 109 ) -> Tensor: 110 if len(new_shape) == 1 and ( 111 isinstance(new_shape[0], tuple) or isinstance(new_shape[0], list) 112 ): 113 return x.reshape(new_shape[0]) 114 assert isinstance(new_shape, Union[torch.Size, Tuple[int, ...], List[int]]) 115 return x.reshape(new_shape) 116 117 118class Transpose(nn.Module): 119 def __init__(self) -> None: 120 super().__init__() 121 122 def forward(self, x: Tensor, dim0: int, dim1: int) -> Tensor: 123 return x.transpose(dim0, dim1) 124 125 126class Mul(nn.Module): 127 def __init__(self) -> None: 128 super().__init__() 129 130 def forward(self, input: Tensor, other: Tensor) -> Tensor: 131 # or return torch.mul(input, other) 132 return input * other 133 134 def get_random_inputs(self) -> Tuple[Tensor, Tensor]: 135 return (torch.randn(3, 2), torch.randn(3, 2)) 136 137 138class ElementwiseAdd(nn.Module): 139 def __init__(self) -> None: 140 super().__init__() 141 142 def forward(self, x: Tensor, y: Tensor) -> Tensor: 143 return x + y 144 145 def get_random_inputs(self) -> Tuple[Tensor, Tensor]: 146 return (torch.randn(1, 3), torch.randn(1, 3)) 147 148 149class BasicSinMax(nn.Module): 150 def __init__(self) -> None: 151 super().__init__() 152 153 def forward(self, x: Tensor) -> Tensor: 154 return torch.sin(x) 155 156 def get_random_inputs(self) -> Tuple[Tensor]: 157 return (torch.randn(100),) 158 159 160class CompositeDelegateModule(torch.nn.Module): 161 def __init__(self) -> None: 162 super().__init__() 163 164 class DelegateAdd(nn.Module): 165 def __init__(self) -> None: 166 super().__init__() 167 168 def forward(self, x: Tensor, y: Tensor) -> Tensor: 169 return [x + y] 170 171 def get_random_inputs(self) -> Tuple[Tensor, Tensor]: 172 return (torch.randn(1, 3), torch.randn(1, 3)) 173 174 delegated_m = DelegateAdd() 175 edge_ir_m = to_edge( 176 export( 177 delegated_m, 178 delegated_m.get_random_inputs(), 179 ) 180 ) 181 lowered_module = LoweredBackendModule( 182 edge_program=edge_ir_m.exported_program(), 183 backend_id="backend_demo", 184 processed_bytes=bytes("basic_module_add", encoding="utf8"), 185 compile_specs=[], 186 ) 187 self.lowered_module: LoweredBackendModule = lowered_module 188 189 def forward(self, a: exir.Value, b: exir.Value, s: Tensor) -> Tensor: 190 res = self.lowered_module(a, b) 191 res = res[0] * s 192 return res 193 194 def get_random_inputs(self) -> Tuple[Tensor, Tensor, Tensor]: 195 return (torch.randn(1, 3), torch.randn(1, 3), torch.randn(1, 3)) 196 197 198class BatchMatrixMultiplication(nn.Module): 199 def __init__(self, transposed: bool = False) -> None: 200 super().__init__() 201 202 # Whether the last 2 dims (-1, -2) of the input has already been 203 # transposed. If yes, transpose it back before feeding to torch.bmm 204 self.transposed: bool = transposed 205 206 def forward(self, x: Tensor, y: Tensor) -> Tensor: 207 if self.transposed: 208 return torch.bmm(x, y.transpose(-1, -2)) 209 else: 210 return torch.bmm(x, y) 211 212 def extra_repr(self) -> str: 213 return f"transposed={self.transposed}" 214 215 216class TensorSplit(nn.Module): 217 def __init__(self) -> None: 218 super().__init__() 219 220 def forward(self, input: Tensor, sections: int, dim: int = 0) -> List[Tensor]: 221 # pyre-fixme[7]: Expected `List[Tensor]` but got `Tuple[Tensor, ...]`. 222 return torch.tensor_split(input, sections, dim) 223 224 225class TensorSplitWithSizes(nn.Module): 226 def __init__(self) -> None: 227 super().__init__() 228 229 def forward(self, input: Tensor, split_size: int, dim: int = 0) -> List[Tensor]: 230 # pyre-fixme[7]: Expected `List[Tensor]` but got `Tuple[Tensor, ...]`. 231 return torch.split(input, split_size, dim) 232 233 234class Cat(nn.Module): 235 def __init__(self) -> None: 236 super().__init__() 237 238 # def forward(self, tensors, dim=0): 239 def forward(self, *args: Tensor, dim: int) -> Tensor: 240 tensors = args[:-1] 241 return torch.cat(tensors, dim) 242 243 244class FeedForwardBlock(nn.Module): 245 def __init__(self, input_dim: int, hidden_dim: int) -> None: 246 super().__init__() 247 self.input_dim = input_dim 248 self.hidden_dim = hidden_dim 249 250 self.layer_norm = nn.LayerNorm(input_dim) 251 252 self.relu = nn.ReLU() 253 254 self.linear1 = nn.Linear(input_dim, hidden_dim) 255 self.dropout1 = nn.Dropout() 256 257 self.linear2 = nn.Linear(hidden_dim, input_dim) 258 self.dropout2 = nn.Dropout() 259 260 def forward(self, x: Tensor) -> Tensor: 261 # LayerNorm -> Linear -> Dropout -> ReLU -> Linear -> Dropout 262 y = self.layer_norm(x) 263 y = self.linear1(y) 264 y = self.dropout1(y) 265 y = self.relu(y) 266 y = self.linear2(y) 267 y = self.dropout2(y) 268 return y 269 270 271class NoOp(nn.Module): 272 """ 273 NoOp simply passes the input as the output. 274 """ 275 276 def __init__(self) -> None: 277 super().__init__() 278 279 def forward(self, input: Tensor) -> Tensor: 280 return input 281 282 283class MultiLayerPerceptron(nn.Module): 284 def __init__( 285 self, 286 input_dim: int, 287 hidden_dim1: int, 288 hidden_dim2: int, 289 hidden_dim3: int, 290 output_dim: int, 291 ) -> None: 292 super().__init__() 293 self.input_dim = input_dim 294 self.hidden_dim1 = hidden_dim1 295 self.hidden_dim2 = hidden_dim2 296 self.hidden_dim3 = hidden_dim3 297 self.output_dim = output_dim 298 self.layers = nn.Sequential( 299 nn.Linear(input_dim, hidden_dim1), 300 nn.ReLU(), 301 nn.Linear(hidden_dim1, hidden_dim2), 302 nn.ReLU(), 303 nn.Linear(hidden_dim2, hidden_dim3), 304 nn.ReLU(), 305 nn.Linear(hidden_dim3, output_dim), 306 ) 307 308 def forward(self, x: Tensor) -> Tensor: 309 return self.layers(x) 310 311 312class ScaledDotProductAttentionModularized(nn.Module): 313 def __init__( 314 self, 315 embed_dim: int, 316 num_heads: int, 317 dropout_p: float = 0.5, 318 ) -> None: 319 super().__init__() 320 self.embed_dim = embed_dim 321 self.num_heads = num_heads 322 self.dropout_p = dropout_p 323 self.dropout = nn.Dropout(p=dropout_p) 324 325 self.head_dim: int = embed_dim // num_heads 326 self.scaling: float = self.head_dim**-0.5 327 328 self.mul = Mul() 329 self.reshape = Reshape() 330 self.transpose = Transpose() 331 self.bmm = BatchMatrixMultiplication(transposed=False) 332 self.bmm_t = BatchMatrixMultiplication(transposed=True) 333 self.softmax = nn.Softmax(dim=-1) 334 335 def forward( 336 self, 337 q: Tensor, 338 k: Tensor, 339 v: Tensor, 340 ) -> Tensor: 341 # q: (L, B, D) k: (S, B, D) v: (S, B, D) 342 # assert k.shape == v.shape 343 # assert q.dim() == 3 and k.dim() == 3 344 # assert q.size(1) == k.size(1) and q.size(2) == k.size(2) 345 346 L, B, D = q.shape 347 S = k.size(0) 348 # assert D % self.head_dim == 0 349 350 # FIXME(poweic): scaling layer!? 351 # this will break the modular assumption, which makes the following 352 # self.reshape to think it is using some floating inputs q because 353 # id(q) is no longer the same id(q) 354 # This is equiv. to `q = q * self.scaling` 355 q = self.mul(q, self.scaling) 356 357 # Reshape & transpose q from (L, B, D) to (B*H, L, D/H) 358 q = self.reshape(q, (L, B * self.num_heads, self.head_dim)) 359 q = self.transpose(q, 0, 1) 360 361 # Reshape & transpose k from (S, B, D) to (B*H, S, D/H) 362 k = self.reshape(k, (S, B * self.num_heads, self.head_dim)) 363 k = self.transpose(k, 0, 1) 364 365 # Reshape & transpose v from (S, B, D) to (B*H, S, D/H) 366 v = self.reshape(v, (S, B * self.num_heads, self.head_dim)) 367 v = self.transpose(v, 0, 1) 368 369 # bmm((B*H, L, D/H), (B*H, D/H, S)) -> (B*H, L, S). 370 # this is equiv. to `qk = torch.bmm(q, k.transpose(-1, -2))` 371 qk = self.bmm_t(q, k) 372 # assert qk.shape == (B * self.num_heads, L, S) 373 374 softmax_qk = self.softmax(qk) 375 376 softmax_qk = self.dropout(softmax_qk) 377 378 # bmm((B*H, L, S), (B*H, S, D/H)) -> (B*H, L, D/H). 379 # this is equiv. to `attention = torch.bmm(softmax_qk, v)` 380 attention = self.bmm(softmax_qk, v) 381 # assert attention.shape == (B * self.num_heads, L, self.head_dim) 382 383 # Transpose & reshape attention: (B*H, L, D/H) -> (L, B*H, D/H) -> (L, B, D). 384 attention = self.transpose(attention, 0, 1) 385 attention = self.reshape(attention, (L, B, self.embed_dim)) 386 387 return attention 388 389 390# ------------------------------------------------------------------------------ 391# Scaled Dot-Product Attention 392# ------------------------------------------------------------------------------ 393class ScaledDotProductAttention(nn.Module): 394 def __init__( 395 self, 396 embed_dim: int, 397 num_heads: int, 398 dropout: Optional[float] = None, 399 ) -> None: 400 if embed_dim % num_heads: 401 raise ValueError( 402 "embed_dim ({}) must be divisible by num_heads ({})".format( 403 embed_dim, num_heads 404 ) 405 ) 406 407 super().__init__() 408 409 self.embed_dim = embed_dim 410 self.num_heads = num_heads 411 if dropout is not None and dropout > 0.0: 412 self.dropout: nn.Module = nn.Dropout(p=dropout) 413 else: 414 self.dropout = NoOp() 415 416 self.head_dim: int = embed_dim // num_heads 417 self.scaling: float = self.head_dim**-0.5 418 419 def forward( 420 self, 421 q: Tensor, 422 k: Tensor, 423 v: Tensor, 424 padding_mask: Optional[Tensor] = None, 425 attention_mask: Optional[Tensor] = None, 426 ) -> Tensor: 427 # q: (L, B, D) k: (S, B, D) v: (S, B, D) 428 # assert k.shape == v.shape 429 # assert q.dim() == 3 and k.dim() == 3 430 # assert q.size(1) == k.size(1) and q.size(2) == k.size(2) 431 432 L, B, D = q.shape 433 S = k.size(0) 434 # assert D % self.head_dim == 0 435 436 q = q * self.scaling 437 q = q.reshape(L, B * self.num_heads, self.head_dim).transpose( 438 0, 1 439 ) # (B*H, L, D/H) 440 441 k = k.reshape(S, B * self.num_heads, self.head_dim).transpose( 442 0, 1 443 ) # (B*H, S, D/H) 444 445 v = v.reshape(S, B * self.num_heads, self.head_dim).transpose( 446 0, 1 447 ) # (B*H, S, D/H) 448 449 # bmm((B*H, L, D/H), (B*H, D/H, S)) -> (B*H, L, S). 450 qk = torch.bmm(q, k.transpose(1, 2)) 451 # assert qk.shape == (B * self.num_heads, L, S) 452 453 # TODO(cfyeh): figure out if we really need input to be float. 454 softmax_qk = nn.functional.softmax(qk.float(), dim=-1) 455 456 # softmax_qk = self.dropout(softmax_qk) 457 458 # bmm((B*H, L, S), (B*H, S, D/H)) -> (B*H, L, D/H). 459 attention = torch.bmm(softmax_qk, v) 460 # assert attention.shape == (B * self.num_heads, L, self.head_dim) 461 462 # (B*H, L, D/H) -> (L, B*H, D/H) -> (L, B, D). 463 attention = attention.transpose(0, 1).reshape(L, B, self.embed_dim) 464 465 return attention 466 467 468class Emformer(nn.Module): 469 def __init__( 470 self, 471 l_dim: int = 32, 472 m_dim: int = 8, 473 c_dim: int = 8, 474 r_dim: int = 8, 475 input_dim: int = 256, 476 ffn_hidden_dim: int = 512, 477 ) -> None: 478 super().__init__() 479 480 self.l_dim = l_dim 481 self.m_dim = m_dim 482 self.c_dim = c_dim 483 self.r_dim = r_dim 484 485 self.input_dim = input_dim 486 self.ffn_hidden_dim = ffn_hidden_dim 487 488 self.split = TensorSplit() 489 self.elem_add = ElementwiseAdd() 490 491 self.attn = ScaledDotProductAttention( 492 embed_dim=input_dim, 493 num_heads=8, 494 ) 495 496 self.ffn = FeedForwardBlock(input_dim, ffn_hidden_dim) 497 498 self.layer_norm = nn.LayerNorm(input_dim) 499 500 self.linear_k = nn.Linear(self.input_dim, self.input_dim) 501 self.linear_v = nn.Linear(self.input_dim, self.input_dim) 502 self.linear_q = nn.Linear(self.input_dim, self.input_dim) 503 504 def get_random_inputs(self) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: 505 inputs = ( 506 torch.randn(self.m_dim, 1, self.input_dim), 507 torch.randn(self.c_dim, 1, self.input_dim), 508 torch.randn(self.r_dim, 1, self.input_dim), 509 torch.randn(self.l_dim, 1, self.input_dim), 510 torch.randn(self.l_dim, 1, self.input_dim), 511 ) 512 return inputs 513 514 def forward( 515 self, M: Tensor, C: Tensor, R: Tensor, K_L: Tensor, V_L: Tensor 516 ) -> Tensor: 517 """ 518 The Emformer block takes [M_i^n, C_i^n, R_i^n] and [K_{L,i}^n, V_{L,i}^n] 519 as inputs and outputs [C_i^{n+1}, R_i^{n+1}]. 520 See Fig. 1(b) Emformer and equations 6, 7, 8 - 13 in the original paper 521 https://arxiv.org/pdf/2010.10759.pdf 522 523 Ex: 524 - self.input_dim = 525 - L.shape = 30 x 1 x 512 526 - M.shape = 2 x 1 x 512 527 - C.shape = 5 x 1 x 512 528 - R.shape = 1 x 1 x 512 529 """ 530 # Equation 8 531 CR = torch.cat([C, R], 0) 532 CR_normed = self.layer_norm(CR) 533 # C_normed = self.layer_norm(C) 534 # R_normed = self.layer_norm(R) 535 536 # Equation 9 and 10 537 if True: 538 MCR = torch.cat([M, C, R], 0) 539 K_MCR = self.linear_k(MCR) 540 V_MCR = self.linear_v(MCR) 541 542 K_M, K_C, K_R = self.split(K_MCR, 3) 543 V_M, V_C, V_R = self.split(V_MCR, 3) 544 else: 545 K_M, K_C, K_R = self.linear_k(M), self.linear_k(C), self.linear_k(R) 546 V_M, V_C, V_R = self.linear_v(M), self.linear_v(C), self.linear_v(R) 547 548 K = torch.cat([K_M, K_L, K_C, K_R], 0) 549 V = torch.cat([V_M, V_L, V_C, V_R], 0) 550 551 # Equation 11 and 12 552 Q_CR = self.linear_q(CR_normed) 553 Z_CR = self.attn(Q_CR, K, V) 554 Z_CR = self.elem_add(Z_CR, CR) 555 # Q_C = self.linear_q(C_normed) 556 # Q_R = self.linear_q(R_normed) 557 # Z_C = self.attn(Q_C, K, V) 558 # Z_R = self.attn(Q_R, K, V) 559 # Z_C = self.elem_add(Z_C, C) 560 # Z_R = self.elem_add(Z_R, R) 561 562 # Equation 6 563 Z_CR_normed = self.layer_norm(Z_CR) 564 ffn_out = self.ffn(Z_CR_normed) 565 566 # Equation 7 567 output = self.layer_norm(self.elem_add(ffn_out, Z_CR)) 568 569 # m = self.attn( 570 571 return output 572 573 574# List of models that we want to export 575# TODO(angelayi): enable ControlFlowWhile test once we enable functionalization 576MODELS = [ 577 ["basic_sin_max", BasicSinMax()], 578 ["composite_delegate", CompositeDelegateModule()], 579] 580