1# Owner(s): ["oncall: fx"] 2 3import itertools 4 5import torch 6from torch.fx.experimental.proxy_tensor import make_fx 7from torch.fx.graph_module import GraphModule 8from torch.fx.passes.dialect.common.cse_pass import CSEPass 9from torch.testing._internal.common_utils import ( 10 instantiate_parametrized_tests, 11 parametrize, 12 run_tests, 13 TestCase, 14) 15 16 17def FactoryFunctionCall(x, device): 18 y = torch.full(x.shape, 3, device=device) 19 z = torch.add(y, x) 20 return z 21 22 23def TorchTensorCall(x): 24 y = torch.tensor(3) 25 return x + y 26 27 28def TakeList(x): 29 z = torch.cat([x, x]) 30 return z 31 32 33def ReturnList(x): 34 a = torch.arange(10).reshape(5, 2) 35 z = torch.split(a, [1, 4]) 36 return z 37 38 39def Mutation(x): 40 y = x + 2 41 y.add_(1) 42 return x + y 43 44 45def MutationInput(x): 46 x.add_(1) 47 y = x + 2 48 return x + y 49 50 51def MutationFactory(x, device): 52 y = torch.full(x.shape, 3, device=device) 53 y.add_(1) 54 return x + y 55 56 57def MutationTorchTensorCall(x): 58 y = torch.tensor(3) 59 y.add_(1) 60 return x + y 61 62 63def MutationMetadata(x): 64 x.resize_(2) 65 return x 66 67 68Passes = [CSEPass] 69Test_Cases = [ 70 TakeList, 71 ReturnList, 72 Mutation, 73 MutationInput, 74 MutationMetadata, 75 MutationTorchTensorCall, 76] 77Factory_Test_Cases = [FactoryFunctionCall, MutationFactory] 78Devices = ["cpu"] 79if torch.cuda.is_available(): 80 Devices.append("cuda") 81 82 83def name_fn(common_pass, f, device): 84 """Names parameterized test cases.""" 85 return f"{type(common_pass()).__name__}_{f.__name__}_{device}" 86 87 88@instantiate_parametrized_tests 89class TestCommonPass(TestCase): 90 @parametrize( 91 "common_pass,f,device", itertools.product(Passes, Test_Cases, Devices), name_fn 92 ) 93 def test_correctness(self, common_pass, f, device): 94 inp = torch.randn(10, device=device) 95 96 traced_m = make_fx(f)(inp) 97 P = common_pass() 98 99 res = P(traced_m) 100 modified_m = res.graph_module 101 assert isinstance(modified_m, GraphModule) 102 103 inp_copy = inp.clone() 104 expected = f(inp) 105 result = modified_m(inp_copy) 106 107 self.assertEqual(result, expected) 108 109 @parametrize( 110 "common_pass,f,device", 111 itertools.product(Passes, Factory_Test_Cases, Devices), 112 name_fn, 113 ) 114 def test_correctness_factory(self, common_pass, f, device): 115 inp = torch.randn(10, device=device) 116 traced_m = make_fx(f)(inp, device) 117 P = common_pass() 118 119 res = P(traced_m) 120 modified_m = res.graph_module 121 assert isinstance(modified_m, GraphModule) 122 123 inp_copy = inp.clone() 124 expected = f(inp, device) 125 result = modified_m(inp_copy, device) 126 127 self.assertEqual(result, expected) 128 129 130if __name__ == "__main__": 131 run_tests() 132