# Owner(s): ["module: unknown"]

import unittest
from typing import Dict, List, Optional

import numpy as np

import torch
from torch import nn
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.static_module import StaticModule


def linear_shim(
    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
    output = input.matmul(weight.t())
    if bias is not None:
        output += bias
    ret = output
    return ret


torch.nn.functional.linear = linear_shim


class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        assert hid_dim % n_heads == 0
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        # self.dropout = nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)

    def forward(self, query, key, value, mask):
        batch_size = query.shape[0]
        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        # energy = energy.masked_fill(mask == 0, -1e10)
        attention = torch.softmax(energy, dim=-1)
        # x = torch.matmul(self.dropout(attention), V)
        x = torch.matmul(attention, V)
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.hid_dim)
        x = self.fc_o(x)
        return x, attention


# Taken from https://github.com/facebookresearch/dlrm/blob/master/dlrm_s_pytorch.py
def create_mlp(ln, sigmoid_layer):
    layers = nn.ModuleList()
    for i in range(0, len(ln) - 1):
        n = ln[i]
        m = ln[i + 1]

        LL = nn.Linear(int(n), int(m), bias=True)

        mean = 0.0  # std_dev = np.sqrt(variance)
        std_dev = np.sqrt(2 / (m + n))  # np.sqrt(1 / m) # np.sqrt(1 / n)
        W = np.random.normal(mean, std_dev, size=(m, n)).astype(np.float32)
        std_dev = np.sqrt(1 / m)  # np.sqrt(2 / (m + 1))
        bt = np.random.normal(mean, std_dev, size=m).astype(np.float32)
        LL.weight.data = torch.tensor(W, requires_grad=True)
        LL.bias.data = torch.tensor(bt, requires_grad=True)
        layers.append(LL)

        if i == sigmoid_layer:
            layers.append(nn.Sigmoid())
        else:
            layers.append(nn.ReLU())

    with torch.no_grad():
        s = torch.jit.script(torch.nn.Sequential(*layers))
    s.eval()
    return s


def trivial_graph(a, b, c):
    s = torch.tensor([[3, 3], [3, 3]])
    return a + b * c + s


def elementwise_square_addition(input1, input2):
    return input1 * input1 + input2 * input2


def fork_wait_graph1(input1, input2):
    fut = torch.jit.fork(elementwise_square_addition, input1, input2)
    return torch.jit.wait(fut)


def fork_wait_graph2(input1, input2):
    fut = torch.jit.fork(loop_graph, input1, input2, 5)
    return torch.jit.wait(fut)


"""
   graph with multiple fork/wait operations
   :param input: torch.tensor input to forked subgraph
   :param iters: number of future/wait pairs to be created
"""
def fork_wait_graph3(input, iters: int):
    futures: List[torch.jit.Future[torch.Tensor]] = []
    for _ in range(iters):
        futures.append(torch.jit.fork(torch.neg, input))
    results = []
    for future in futures:
        results.append(torch.jit.wait(future))
    return torch.sum(torch.stack(results))
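
# Note: when the fork/wait graphs in this file are scripted, torch.jit.fork
# produces a prim::fork node (the op referenced in the exception-handling test
# docstrings below) and torch.jit.wait blocks on the future it returns; the
# TestStaticModule tests check that Static Runtime executes those forked
# subgraphs, and propagates their exceptions, the same way the scripted graph does.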

"""
   graph with multi-level fork/wait operations
   :param input: torch.tensor input to forked subgraph
   :param num_forks: number of top level forks
   :param num_child_forks: number of child forks per parent fork
"""
def fork_wait_graph4(input, num_forks: int, num_child_forks: int):
    futures: List[torch.jit.Future[torch.Tensor]] = []
    for _ in range(num_forks):
        futures.append(torch.jit.fork(fork_wait_graph3, input, num_child_forks))
    results = []
    for future in futures:
        results.append(torch.jit.wait(future))
    return torch.sum(torch.stack(results))


def add_tensor(input1, input2):
    return input1 + input2


def fork_wait_graph_exception(input1, input2):
    fut = torch.jit.fork(add_tensor, input1, input2)
    return torch.jit.wait(fut)


def loop_graph(a, b, iters: int):
    c = a + b * 2
    for i in range(iters):
        c = c + b
        c *= 2
        c -= a
    return c


def output_graph(a, b, c, iters: int):
    s = torch.tensor([[3, 3], [3, 3]])
    k = a + b * c + s
    d: Dict[int, torch.Tensor] = {}
    for i in range(iters):
        d[i] = k + i
    return d


class SubModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.a = 11
        self.b = 2

    def forward(self, x):
        return self.a + self.b + x


class SubModule2(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.a = 12
        self.b = 2

    def forward(self, x):
        self.b = 30
        return self.a + self.b + x


class TestModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.sub1 = SubModule()
        self.sub2 = SubModule2()
        self.a = 3
        self.b = 4

    def forward(self, x):
        self.b = 20
        return self.sub1(x) + self.a + self.b + self.sub2(x)


class TestStaticModule(TestCase):

    """
    Test Case: To test a simple fork/wait operation in a graph;
    fork is called on a simple addition of the input tensors.
    """
    def test_fork_wait_1(self):
        inp1 = torch.ones(5, 5)
        inp2 = torch.randn(5, 5)
        torch_graph = torch.jit.script(fork_wait_graph1)
        output_ref = torch_graph(inp1, inp2)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module(inp1, inp2)
        torch.testing.assert_close(output_test, output_ref)

    """
    Test Case: To test a simple fork/wait operation with the
    StaticRuntime runAsync API returning a future.
    """
    def test_fork_wait_1_async(self):
        inp1 = torch.ones(5, 5)
        inp2 = torch.randn(5, 5)
        torch_graph = torch.jit.script(fork_wait_graph1)
        output_ref = torch_graph(inp1, inp2)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module.runAsync((inp1, inp2), {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

    """
    Test Case: To test a fork/wait operation on a loop subgraph
    performing a mix of operations.
    """
    def test_fork_wait_2(self):
        inp1 = torch.randn(5, 5)
        inp2 = torch.randn(5, 5)
        torch_graph = torch.jit.script(fork_wait_graph2)
        output_ref = torch_graph(inp1, inp2)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module(inp1, inp2)
        torch.testing.assert_close(output_test, output_ref)

    """
    Test Case: To test a fork/wait operation on a loop subgraph
    with the StaticRuntime runAsync API returning a future.
    """
    def test_fork_wait_2_async(self):
        inp1 = torch.randn(5, 5)
        inp2 = torch.randn(5, 5)
        torch_graph = torch.jit.script(fork_wait_graph2)
        output_ref = torch_graph(inp1, inp2)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module.runAsync((inp1, inp2), {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

    """
    Test Case: To test a graph containing multiple
    fork/wait operations.
    """
    def test_fork_wait_3(self):
        input = torch.ones(3, 3)
        num_forks = 10
        torch_graph = torch.jit.script(fork_wait_graph3)
        output_ref = torch_graph(input, num_forks)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module(input, num_forks)
        torch.testing.assert_close(output_test, output_ref)

    """
    Test Case: To test a graph with multiple fork/wait operations
    on the runAsync API returning a future.
    """
    def test_fork_wait_3_async(self):
        input = torch.ones(3, 3)
        num_forks = 10
        torch_graph = torch.jit.script(fork_wait_graph3)
        output_ref = torch_graph(input, num_forks)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module.runAsync((input, num_forks), {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

    """
    Test Case: To test a graph with multiple nested
    fork/wait operations.
    """
    @unittest.skip("Broken test: https://github.com/pytorch/pytorch/issues/109782")
    def test_fork_wait_4(self):
        input = torch.ones(3, 3)
        num_forks = 10
        num_child_forks = 10
        torch_graph = torch.jit.script(fork_wait_graph4)
        static_runtime_module = StaticModule(torch_graph)
        output_ref = torch_graph(input, num_forks, num_child_forks)
        output_test = static_runtime_module(input, num_forks, num_child_forks)
        torch.testing.assert_close(output_test, output_ref)

    """
    Test Case: To test a graph with multiple nested fork/wait
    operations on the runAsync API returning a future.
    """
    @unittest.skip("Broken test: https://github.com/pytorch/pytorch/issues/109782")
    def test_fork_wait_4_async(self):
        input = torch.ones(3, 3)
        num_forks = 10
        num_child_forks = 10
        torch_graph = torch.jit.script(fork_wait_graph4)
        static_runtime_module = StaticModule(torch_graph)
        output_ref = torch_graph(input, num_forks, num_child_forks)
        output_test = static_runtime_module.runAsync(
            (input, num_forks, num_child_forks), {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

    """
    Test Case: To test exception handling in a fork/wait operation.
    The add.Tensor op is called on tensors with non-matching dims in the
    forked subgraph, and the exception raised by the subgraph is set on the
    future returned by prim::fork to the parent graph. The returned exception
    is checked for the substring expected_error_msg declared below.
    """
    def test_fork_wait_exception(self):
        # incompatible tensors for add due to shape mismatch
        input1 = torch.randn(4, 7)
        input2 = torch.randn(4, 5)
        torch_graph = torch.jit.script(fork_wait_graph_exception)
        try:
            static_runtime_module = StaticModule(torch_graph)
            output_test = static_runtime_module(input1, input2)
        except Exception as error:
            expected_error_msg = (
                "The size of tensor a (7) must match the size "
                "of tensor b (5) at non-singleton dimension 1"
            )
            # test fails if error does not contain expected substr
            if str(error).find(expected_error_msg) == -1:
                raise RuntimeError(
                    "Tried execution of add.Tensors with incompatible shape. "
                    "Exception raised by forked runtime execution does "
                    f'not contain expected substring: "{expected_error_msg}"'
                ) from error

    """
    Test Case: To test exception handling in a fork/wait operation with the
    runAsync API. The add.Tensor op is called on tensors with non-matching
    dims in the forked subgraph, and the exception raised by the subgraph is
    set on the future returned by prim::fork to the parent graph. The returned
    exception is checked for the substring expected_error_msg declared below.
    """
    def test_fork_wait_exception_async(self):
        # incompatible tensors for add due to shape mismatch
        input1 = torch.randn(4, 7)
        input2 = torch.randn(4, 5)
        torch_graph = torch.jit.script(fork_wait_graph_exception)
        try:
            static_runtime_module = StaticModule(torch_graph)
            output_test = static_runtime_module.runAsync(
                (input1, input2), {})
        except Exception as error:
            expected_error_msg = (
                "The size of tensor a (7) must match the size "
                "of tensor b (5) at non-singleton dimension 1"
            )
            # test fails if error does not contain expected substr
            if str(error).find(expected_error_msg) == -1:
                raise RuntimeError(
                    "Tried execution of add.Tensors with incompatible shape. "
                    "Exception raised by forked runtime execution does "
                    f'not contain expected substring: "{expected_error_msg}"'
                ) from error

    def test_multihead_attention_layer(self):
        HID_DIM = 256
        QUERY_LEN = 8
        BATCH_SIZE = 128
        LAYERS = 3
        HEADS = 8
        DROPOUT = 0.1
        device = torch.device("cpu")
        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
        with torch.no_grad():
            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
        src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

        attention.eval()
        attention = torch.jit.script(attention)
        attention.eval()
        o_ref = attention(src, src, src, src_mask)

        attention_a = StaticModule(attention)
        o_test = attention_a(src, src, src, src_mask)
        o_test_kw = attention_a(src, src, value=src, mask=src_mask)

        for a, b in zip(o_ref, o_test):
            torch.testing.assert_close(a, b)

        for a, b in zip(o_ref, o_test_kw):
            torch.testing.assert_close(a, b)

    def test_multihead_attention_layer_benchmark(self):
        HID_DIM = 256
        QUERY_LEN = 8
        BATCH_SIZE = 128
        LAYERS = 3
        HEADS = 8
        DROPOUT = 0.1
        device = torch.device("cpu")
        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
        with torch.no_grad():
            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
        src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

        attention.eval()
        attention = torch.jit.script(attention)
        attention_a = StaticModule(attention)

        attention_a.benchmark([src, src, src, src_mask], {}, 2, 2)
        metrics = attention_a.benchmark_individual_ops(
            [src, src, src, src_mask], {}, 2, 2
        )

    def test_mlp(self):
        # Arguments taken from benchmark script, ./bench/dlrm_s_benchmark.sh
        ln_bot = [512, 512, 64]
        sigmoid_bot = -1
        ln_top = [100, 1024, 1024, 1024, 1]
        sigmoid_top = 3
        bot_l = create_mlp(ln_bot, sigmoid_bot)
        bot_l_acc = StaticModule(bot_l)
        top_l = create_mlp(ln_top, sigmoid_top)
        top_l_acc = StaticModule(top_l)
        with torch.no_grad():
            bot_inp = torch.randn(2048, 512)  # torch.Size([2048, 512])
            top_inp = torch.randn(2048, 100)  # torch.Size([2048, 100])
        ref_bot = bot_l(bot_inp)
        acc_bot = bot_l_acc(bot_inp)
        torch.testing.assert_close(acc_bot, ref_bot)
        ref_top = top_l(top_inp)
        acc_top = top_l_acc(top_inp)
        torch.testing.assert_close(acc_top, ref_top)
        for _ in range(5):
            with torch.no_grad():
                bot_inp = torch.randn(2048, 512)  # torch.Size([2048, 512])
                top_inp = torch.randn(2048, 100)  # torch.Size([2048, 100])
            ref_bot = bot_l(bot_inp)
            acc_bot = bot_l_acc(bot_inp)
            torch.testing.assert_close(acc_bot, ref_bot)
            ref_top = top_l(top_inp)
            acc_top = top_l_acc(top_inp)
            torch.testing.assert_close(acc_top, ref_top)

    def test_trivial_graph(self):
        s = torch.full((2, 2), 2)
        tg = torch.jit.script(trivial_graph)
        o_ref = tg(s, s, s)
        tg_a = StaticModule(tg)
        o_test = tg_a(s, s, s)
        torch.testing.assert_close(o_ref, o_test)
    def test_leaky_relu(self):
        s = torch.randn(5, 5)
        tg = torch.jit.script(nn.LeakyReLU(0.1))
        o_ref = tg(s)
        tg_a = StaticModule(tg)
        o_test = tg_a(s)
        torch.testing.assert_close(o_ref, o_test)

    def test_attr(self):
        """
        TorchScript IR of TestModule() after freezing:
        graph(%self : __torch__.test_static_runtime.___torch_mangle_0.TestModule,
              %x.1 : Tensor):
            %18 : int = prim::Constant[value=30]()
            %30 : int = prim::Constant[value=13]()
            %3 : int = prim::Constant[value=20]()
            %2 : int = prim::Constant[value=1]()
            %self.sub2.a : int = prim::Constant[value=12]()
            %self.a : int = prim::Constant[value=3]()
            = prim::SetAttr[name="b"](%self, %3)
            %17 : Tensor = aten::add(%x.1, %30, %2)
            %7 : Tensor = aten::add(%17, %self.a, %2)
            %b.1 : int = prim::GetAttr[name="b"](%self)
            %9 : Tensor = aten::add(%7, %b.1, %2)
            %sub2 : __torch__.test_static_runtime.___torch_mangle_2.SubModule2 = prim::GetAttr[name="sub2"](%self)
            = prim::SetAttr[name="b"](%sub2, %18)
            %b : int = prim::GetAttr[name="b"](%sub2)
            %22 : int = aten::add(%self.sub2.a, %b)
            %23 : Tensor = aten::add(%x.1, %22, %2)
            %12 : Tensor = aten::add(%9, %23, %2)
            return (%12)
        """
        # test prim::SetAttr and prim::GetAttr impl in Static Runtime
        m = TestModule()

        m.eval()
        input = torch.randn(2, 2)
        output_s = m.forward(input)

        ms = torch.jit.script(m)
        sm = StaticModule(ms)
        output_sm = sm(input)
        torch.testing.assert_close(output_s, output_sm)
        sm.benchmark([input], {}, 2, 2)
        sm.benchmark_individual_ops([input], {}, 2, 2)
        sm.benchmark([], {"x": input}, 2, 2)
        sm.benchmark_individual_ops([], {"x": input}, 2, 2)

    @unittest.skip("Temporarily disabled")
    def test_fusion_trivial_graph(self):
        s = torch.full((2, 2), 2)
        tg = torch.jit.script(trivial_graph)
        o_ref = tg(s, s, s)
        torch._C._fuse_to_static_module(tg.graph)
        assert "StaticSubgraph" in str(tg.graph)
        o_test = tg(s, s, s)
        torch.testing.assert_close(o_ref, o_test)

    @unittest.skip("Temporarily disabled")
    def test_fusion_multihead_attention_layer(self):
        HID_DIM = 256
        QUERY_LEN = 8
        BATCH_SIZE = 128
        LAYERS = 3
        HEADS = 8
        DROPOUT = 0.1
        device = torch.device("cpu")
        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
        with torch.no_grad():
            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
        src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

        attention.eval()
        attention = torch.jit.script(attention)
        attention.eval()
        o_ref = attention(src, src, src, src_mask)

        torch._C._fuse_to_static_module(attention._c)
        o_test = attention(src, src, src, src_mask)

        for a, b in zip(o_ref, o_test):
            torch.testing.assert_close(a, b)

    @unittest.skip("Temporarily disabled")
    def test_fusion_loop(self):
        a = torch.randn(5, 5)
        b = torch.randn(5, 5)
        c = 4
        lg = torch.jit.script(loop_graph)
        o_ref = lg(a, b, c)
        torch._C._fuse_to_static_module(lg.graph)
        assert "StaticSubgraph" in str(lg.graph)
        o_test = lg(a, b, c)
        torch.testing.assert_close(o_ref, o_test)

    @unittest.skip("Temporarily disabled")
    def test_fusion_outputs(self):
        a = torch.randn(2, 2)
        b = torch.randn(2, 2)
        c = 4
        og = torch.jit.script(output_graph)
        o_ref = og(a, b, b, c)
        torch._C._fuse_to_static_module(og.graph)
        assert "StaticSubgraph" in str(og.graph)
        o_test = og(a, b, b, c)
        for i in o_ref.keys():
            torch.testing.assert_close(o_ref[i], o_test[i])

    def test_create_object(self):
        class Foo:  # noqa: B903
            def __init__(self, x: torch.Tensor) -> None:
                self.x = x

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, y: torch.Tensor) -> torch.Tensor:
                foo = Foo(y)
                return y * foo.x

        mod = torch.jit.script(Mod()).eval()
        y = torch.randn((1, ))
        expected = mod(y)

        static_mod = StaticModule(torch.jit.freeze(mod))
        actual = static_mod(y)

        self.assertEqual(expected, actual)
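

# The sketch below is a minimal, self-contained illustration of the Static
# Runtime usage pattern exercised throughout TestStaticModule (script a graph,
# wrap it in StaticModule, then run it synchronously or via the runAsync API).
# `example_static_module_usage` is a hypothetical helper added for
# documentation only; it is not invoked by any test above.
def example_static_module_usage() -> None:
    scripted = torch.jit.script(fork_wait_graph1)
    smod = StaticModule(scripted)
    a, b = torch.ones(5, 5), torch.randn(5, 5)

    # Synchronous execution: behaves like calling the scripted graph directly.
    out_sync = smod(a, b)

    # Asynchronous execution: runAsync takes (args, kwargs) and returns a
    # future; wait() blocks until completion and value() retrieves the result.
    fut = smod.runAsync((a, b), {})
    fut.wait()
    out_async = fut.value()

    torch.testing.assert_close(out_sync, out_async)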


if __name__ == "__main__":
    run_tests()