# Owner(s): ["module: functorch"]

import torch
from functorch import make_fx
from functorch.compile import minifier
from torch._functorch.compile_utils import get_outputs, get_placeholders
from torch.testing._internal.common_utils import run_tests, TestCase


class TestMinifier(TestCase):
    """Tests for ``functorch.compile.minifier`` on small traced graphs.

    Every test follows the same shape: build a graph module, define a
    predicate ``(fx_g, inps) -> bool`` that is truthy while the graph still
    shows the property of interest, run the minifier, and check the number of
    nodes and inputs in what it returns.
    """

    def test_has_mul_minifier(self):
        """Minify with a predicate that only requires an ``aten.mul.Tensor`` node."""

        def failing_f(x, y):
            y = y / 3
            x = x + 3
            x = x * y
            return x + y

        inps = [torch.randn(3), torch.randn(3)]
        failing_f = make_fx(failing_f)(*inps)

        def has_mul(fx_g, inps):
            # Truthy as long as some node in the graph targets aten.mul.Tensor.
            return torch.ops.aten.mul.Tensor in (i.target for i in fx_g.graph.nodes)

        min_f, inps = minifier(failing_f, inps, has_mul)
        self.assertEqual(len(min_f.graph.nodes), 4)
        self.assertEqual(len(inps), 2)

    def test_has_add_mul(self):
        """Minify a function whose output is NaN because of ``zero / zero``."""

        def failing_f(x):
            x = x * 3
            x = x + 5
            x = x.cos()
            zero = x - x
            result = zero / zero
            result = result + 3
            return (result * 2,)

        inps = [torch.randn(3)]
        failing_f = make_fx(failing_f)(*inps)

        def has_nans(fx_g, inps):
            # Reject candidates whose inputs already contain NaNs, so that a
            # NaN in the output has to be produced by the graph itself.
            for i in inps:
                if torch.isnan(i).any():
                    return False
            return torch.isnan(fx_g(*inps)[0]).any()

        min_f, inps = minifier(failing_f, inps, has_nans)
        self.assertEqual(len(min_f.graph.nodes), 3)
        self.assertEqual(len(inps), 1)

    def test_input_returned(self):
        """Minify with a predicate that requires a placeholder among the outputs."""

        def f(a, b, c):
            a = a.sin()
            c = c.cos()
            d = a * c
            return (a, b, c, d)

        inps = [torch.randn(3) for _ in range(3)]

        def inputs_returned(fx_g, inps):
            # Truthy when get_placeholders() and get_outputs() have at least
            # one element in common. The tensor inputs are not consulted; only
            # the graph structure matters for this predicate.
            placeholders = set(get_placeholders(fx_g.graph))
            outputs = set(get_outputs(fx_g.graph))
            return len(placeholders & outputs) > 0

        failing_f = make_fx(f)(*inps)
        min_f, inps = minifier(failing_f, inps, inputs_returned)
        self.assertEqual(len(min_f.graph.nodes), 2)
        self.assertEqual(len(inps), 1)

    def test_tup_use(self):
        """Minify a graph that indexes into the tuple returned by ``torch.std_mean``."""

        def f(a, b):
            tup = torch.std_mean(a)
            return (tup[0] + b * tup[1],)

        inps = [torch.randn(3), torch.randn(3)]

        def has_add(fx_g, inps):
            # Truthy as long as some node in the graph targets aten.add.Tensor.
            return torch.ops.aten.add.Tensor in (i.target for i in fx_g.graph.nodes)

        failing_f = make_fx(f)(*inps)
        min_f, inps = minifier(failing_f, inps, has_add)

        self.assertEqual(len(min_f.graph.nodes), 4)
        self.assertEqual(len(inps), 2)

    def test_module(self):
        """Minify a graph obtained from ``torch.fx.symbolic_trace`` on an ``nn.Module``."""

        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                y = self.relu(x)
                zero = y - y
                result = zero / zero
                result = result + 3
                return result

        mod = MockModule()
        failing_f = torch.fx.symbolic_trace(mod)

        inps = [torch.randn(3)]

        def pass_checker(fx_g, inps):
            # Basically, make sure none of the inputs are nans
            for i in inps:
                if torch.isnan(i).any():
                    return False
            # NOTE(review): forward() above returns a bare tensor rather than a
            # tuple, so for the unminified module `[0]` picks the tensor's first
            # element -- confirm this is intended.
            return torch.isnan(fx_g(*inps)[0]).any()

        min_f, inps = minifier(failing_f, inps, pass_checker)
        # Use assertEqual like the other tests: bare `assert` statements are
        # removed under `python -O` and give no diff on failure.
        self.assertEqual(len(min_f.graph.nodes), 3)
        self.assertEqual(len(inps), 1)


if __name__ == "__main__":
    run_tests()