1# Owner(s): ["module: fx"] 2 3import copy 4import unittest 5from collections import defaultdict 6 7import torch 8import torch.fx as fx 9from torch._dynamo.source import LocalSource 10from torch.fx.experimental.shape_inference.infer_shape import infer_shape 11from torch.fx.experimental.shape_inference.infer_symbol_values import ( 12 infer_symbol_values, 13) 14from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv 15 16 17class TestShapeInference(unittest.TestCase): 18 def test_infer_symbol_values(self): 19 def mksym(shape_env, value, source, dynamic_dim) -> None: 20 return shape_env.create_symintnode( 21 shape_env.create_symbol( 22 value, 23 source=source, 24 dynamic_dim=dynamic_dim, 25 ), 26 hint=value, 27 source=source, 28 ) 29 30 shape_env = ShapeEnv() 31 N = 8 32 sample = {f"s{i}": 2 for i in range(N)} 33 init_symints = [ 34 mksym(shape_env, v, LocalSource(k), DimDynamic.DYNAMIC) 35 for k, v in sample.items() 36 ] 37 symints = copy.deepcopy(init_symints) 38 symbol_to_idx_dict = {f"s{i}": i for i in range(N)} 39 padding_constraints = defaultdict(list) 40 41 # prepare constraints strings 42 constraints = [] 43 constraints.append( 44 "The size of tensor a (s1) must match the size of tensor b (1773) at non-singleton dimension 1)" 45 ) 46 constraints.append( 47 "Expected size for first two dimensions of batch2 tensor to be: [s0, (s2//2) + 12] but got: [s0, 120]." 48 ) 49 constraints.append("shape '[s0, -1, 32]' is invalid for input of size s0*s3") 50 constraints.append( 51 "a and b must have same reduction dim, but got [32*s0, s3] X [20, 15]." 52 ) 53 constraints.append( 54 "a and b must have same reduction dim, but got [s0, s4 + 1568] X [5728, 1024]." 55 ) 56 constraints.append( 57 "Expected size for first two dimensions of batch2 tensor to be: [s0, 40] but got: [s0, s5]." 58 ) 59 constraints.append( 60 "shape '[s0, -1, 32]' is invalid for input of size s0*s6 + 1344*s0" 61 ) 62 constraints.append( 63 "shape '[-1, 47]' is invalid for input of size 32*s0*s6 + 1344*s0" 64 ) 65 constraints.append( 66 "Expected size for first two dimensions of batch2 tensor to be: [s0, 47*s6] but got: [s0*s6, 47]." 67 ) 68 constraints.append("Split sizes add up to 4258 but got the tensor's size of s7") 69 70 for constraint in constraints: 71 infer_symbol_values( 72 symints, 73 init_symints, 74 symbol_to_idx_dict, 75 padding_constraints, 76 constraint, 77 ) 78 79 self.assertEqual(symints[1], 1773) 80 self.assertEqual(symints[2], 216) 81 self.assertEqual(symints[3], 640) 82 self.assertEqual(symints[4], 4160) 83 self.assertEqual(symints[5], 40) 84 self.assertEqual(symints[6], 160) 85 self.assertEqual(symints[7], 4258) 86 87 def test_infer_shape(self): 88 class TestModule(torch.nn.Module): 89 def __init__(self) -> None: 90 super().__init__() 91 self.w_1 = torch.empty([256, 328]) 92 self.b_1 = torch.empty([256]) 93 self.w_2 = torch.empty([328, 256]) 94 self.b_2 = torch.empty([328]) 95 96 def forward(self, x): 97 l_1 = torch.nn.functional.linear(x, self.w_1, bias=self.b_1) 98 s_1 = torch.sigmoid(l_1) 99 l_2 = torch.nn.functional.linear(s_1, self.w_2, bias=self.b_2) 100 t_1 = torch.tanh(l_2) 101 return t_1 102 103 def generate_graph_module(model): 104 gm = fx.symbolic_trace(model) 105 return gm 106 107 m = TestModule() 108 gm = generate_graph_module(m) 109 input_tensors = [torch.randn(1, 1)] 110 infer_shape(gm, input_tensors) 111