# Owner(s): ["module: fx.passes"]

from dataclasses import dataclass
import operator
import logging
import sys

import torch
from torch.fx._symbolic_trace import symbolic_trace
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport
from torch.fx.passes.utils.fuser_utils import fuse_by_partitions
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher

from torch.testing._internal.common_utils import run_tests, parametrize, instantiate_parametrized_tests
from torch.testing._internal.jit_utils import JitTestCase

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

class TestModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.linear2 = torch.nn.Linear(4, 4)
        self.param = torch.nn.Parameter(torch.rand(4, 4))

    def forward(self, a, b, c):
        add = a + b

        linear_1 = self.linear(add)

        add_1 = add + c
        add_2 = add_1 + self.param
        add_3 = add_1 + linear_1
        add_4 = add_2 + add_3

        linear_2 = self.linear2(add_4)

        add_5 = linear_2 + add_4
        add_6 = add_5 + a
        relu = add_6.relu()

        return add_4, add_6, relu

class TestDeepModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, a, b, c):
        o = a + b
        o = o + 1.0

        # The loop is deeper than Python's max recursion depth, to verify
        # that passes avoid recursive DFS over the graph.
        for _ in range(sys.getrecursionlimit() + 1):
            o = o - c

        return o

class TestPartitionFunctions:
    @staticmethod
    def forward1(a, b, c):
        add = a + b
        add_1 = add + b
        add_2 = add_1 + c
        relu_1 = add_2.relu()
        add_3 = add_1 + add_2
        add_4 = add_1 + relu_1 + add_3
        relu_2 = add_4.relu()
        add_5 = relu_2 + add_4
        add_6 = add_5 + add_4
        return add_4, add_6

    @staticmethod
    def forward2(a, b, _):
        add = a + b
        add_1 = add + b
        relu_1 = add_1.relu()  # blocked by this
        add_3 = add_1 + relu_1
        add_4 = add_1 + add_3
        return add_4, add_1

    @staticmethod
    def forward3(a, b, c):
        add = a + b
        add_1 = a + c
        add_2 = b + c
        return add, add_1, add_2

    @staticmethod
    def forward4(a, b, c):
        add = a + b
        add_1 = a + c
        add_2 = b + c
        return torch.where(add > 0, add_1, add_2)

    @staticmethod
    def forward5(a, b, c):
        # add should be fused into the right branch, as the left branch is not supported
        add = a + 1
        # left branch
        relu = add.relu()
        # right branch
        add_1 = add + 2
        return relu, add_1

    @staticmethod
    def forward6(a, b, c):
        # add should have its own partition, as neither branch is supported
        add = a + 1
        # left branch
        relu = add.relu()
        # right branch
        relu_1 = add.relu()
        return relu, relu_1

    @staticmethod
    def forward7(a, b, c):
        # both branches are supported, all adds should be fused together
        add = a + 1
        # left branch
        add_1 = add + 2
        # right branch is larger
        add_2 = add + 1
        add_3 = add_2 + 1
        return add_3, add_1

    @staticmethod
    def forward8(a, b, c):
        # both branches are in the same partition, add should join the same partition
        add = a + 1
        # left branch
        add_1 = add + 2
        # right branch
        add_2 = add + 1
        # left and right branches merge
        add_3 = add_2 + add_1
        return add_3

    @staticmethod
    def forward9(a, b, c):
        add = a + 1
        # branch 1
        add_1 = add + 1
        # branch 2
        add_2 = add + 1
        # branch 3
        add_3 = add + 1
        out = torch.stack([add_1, add_2, add_3])
        return out

    @staticmethod
    def forward10(a, b, c):
        add = a + 1
        # branch 1
        add_1 = add + 1
        # branch 2
        add_2 = add + 1
        # branch 3: depends on branch 2
        add_3 = add + add_2
        out = torch.stack([add_1, add_2, add_3])
        return out

    @staticmethod
    def forward11(a, b, c):
        add = a + 1
        # branch 1
        add_1 = add.relu()
        # branch 2 depends on branch 1
        add_2 = add + add_1
        # branch 3
        add_3 = add.relu()
        out = torch.stack([add_1, add_2, add_3])
        return out

    @staticmethod
    def forward12(a, b, c):
        b0 = a + 1.0
        c0 = a + 1.5
        x0 = b0.relu()
        x1 = c0.relu()
        b1 = b0 + x1
        c1 = c0 + 1.2
        # c2 depends on x0 & b0; when we merge {c0, c1, c2}, this dependency
        # should be transferred to the fusion group and reflected in the
        # decision not to fuse b0 & b1, which would form a cyclic dependency
        # in the new graph
        c2 = x0 + c0
        return b1, c2

    @staticmethod
    def forward13(a, b, c):
        a0, a1, a2, a3 = a.split(1, 0)
        b1 = a0 + b
        c1 = a1 + c
        return b1 + c1

    @staticmethod
    def forward14(a, b, c):
        a0, a1 = torch.ops.aten.std_mean(a)
        out = a0 + 1.0
        return out

    @staticmethod
    def forward15(a, b, c):
        a0 = torch.ops.aten.view(a, [2, 2])
        a1 = torch.ops.aten.permute(a0, [1, 0])
        a2 = a1 + 1.0
        a3 = torch.ops.aten.permute(a2, [1, 0])
        a4 = a3 + 1.0
        a5 = torch.ops.aten.permute(a4, [1, 0])
        return torch.ops.aten.permute(a5, [1, 0])

    @staticmethod
    def forward16(a, b, c):
        a0 = a - 1.0
        a1 = torch.ops.aten.view(a0, [2, 2])
        a2 = torch.ops.aten.permute(a1, [1, 0])
        a3 = a2 + 1.0
        a4 = torch.ops.aten.permute(a3, [1, 0])
        a5 = a4 + 1.0
        a6 = torch.ops.aten.permute(a5, [1, 0])
        a7 = torch.ops.aten.permute(a6, [1, 0])
        return a7 - 1.0

    @staticmethod
    def forward17(a, b, c, d, e, f):
        a0 = a + b
        a1 = c + d
        a2 = e + f
        return a0, a1, a2

    @staticmethod
    def forward18(a, b, c):
        a0, a1 = torch.ops.aten.var_mean(a)
        return a0

# A mock OperatorSupport class, where only a small set of call_function ops
# (add, getitem, view, permute, std_mean) is supported
class MockOperatorSupport(OperatorSupport):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return (node.op == "call_function" and
                node.target in {operator.add, operator.getitem,
                                torch.ops.aten.view,
                                torch.ops.aten.permute,
                                torch.ops.aten.std_mean})
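
# For reference, the flow exercised by the partitioner tests below is,
# in essence:
#
#     traced = symbolic_trace(fn)
#     partitioner = CapabilityBasedPartitioner(
#         traced, MockOperatorSupport(), allows_single_node_partition=True)
#     partitions = partitioner.propose_partitions()
#     fused_graph = partitioner.fuse_partitions(partitions)
#
# propose_partitions() groups supported nodes without introducing cyclic
# dependencies between partitions; fuse_partitions() then outlines each
# partition into a call_module node.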

@instantiate_parametrized_tests
class TestFXGraphPasses(JitTestCase):

    @parametrize("fn, expected_partition, bookend_non_compute_pass", [
        (TestPartitionFunctions.forward1, [["add_7", "add_6"], ["add_5", "add_4", "add_3"], ["add_2", "add_1", "add"]], False),
        (TestPartitionFunctions.forward2, [["add_3", "add_2"], ["add_1", "add"]], False),

        # 1 horizontal fusion with common producer
        (TestPartitionFunctions.forward3, [["add_2", "add_1", "add"]], False),
        (TestPartitionFunctions.forward4, [["add_2", "add_1", "add"]], False),

        # 2 branch cases
        (TestPartitionFunctions.forward5, [["add_1", "add"]], False),
        (TestPartitionFunctions.forward6, [["add"]], False),
        (TestPartitionFunctions.forward7, [["add_3", "add_2", "add", "add_1"]], False),
        (TestPartitionFunctions.forward8, [["add_3", "add_2", "add", "add_1"]], False),

        # 3 branch cases
        (TestPartitionFunctions.forward9, [['add_3', 'add_2', 'add_1', 'add']], False),
        (TestPartitionFunctions.forward10, [['add_3', 'add_2', 'add', 'add_1']], False),
        (TestPartitionFunctions.forward11, [['add_1'], ['add']], False),

        # 4 not necessarily the only valid partitioning; just verifies that
        # there is no cyclic dependency after partitioning
        (TestPartitionFunctions.forward12, [["add_2", "add_3", "add_4"], ["add", "add_1"]], False),

        # 5 getitem special case
        (TestPartitionFunctions.forward13, [["add_2", "add_1", "add"]], False),
        (TestPartitionFunctions.forward14, [["add", "std_mean", "getitem", "getitem_1"]], False),

        # 6 bookend non_compute pass
        (TestPartitionFunctions.forward15, [["permute_1", "add_1", "add"]], True),
        (TestPartitionFunctions.forward15, [['add_1', 'add', 'permute_1', 'view', 'permute_2', 'permute_3', 'permute']], False),
        (TestPartitionFunctions.forward16, [["permute_1", "add_1", "add"]], True),
        (TestPartitionFunctions.forward16, [['add_1', 'add', 'permute_1', 'view', 'permute_2', 'permute_3', 'permute']], False),

        # should be an empty partition list, not a partition with no nodes
        (TestPartitionFunctions.forward18, [], False),
    ])
    def test_partitioner(self, fn, expected_partition, bookend_non_compute_pass):
        traced = symbolic_trace(fn)

        non_compute_ops = []
        if bookend_non_compute_pass:
            non_compute_ops = ["torch.ops.aten.view", "torch.ops.aten.permute"]

        supported_ops = MockOperatorSupport()
        partitioner = CapabilityBasedPartitioner(traced,
                                                 supported_ops,
                                                 allows_single_node_partition=True,
                                                 non_compute_ops=non_compute_ops)
        partitions = partitioner.propose_partitions()
        if bookend_non_compute_pass:
            partitioner.remove_bookend_non_compute_ops(partitions)

        partitions_name = [[node.name for node in partition.nodes] for partition in partitions]
        assert len(partitions_name) == len(expected_partition)
        for i in range(len(partitions_name)):
            assert set(partitions_name[i]) == set(expected_partition[i])

        fused_graph = partitioner.fuse_partitions(partitions)

        a, b, c = torch.rand(4), torch.rand(4), torch.rand(4)

        expected = fn(a, b, c)
        result = fused_graph(a, b, c)
        torch.testing.assert_close(expected, result)

    @parametrize("fn, expected_partition", [
        (TestPartitionFunctions.forward17, [['add', 'add_1', 'add_2']]),
    ])
    def test_partitioner_independent_output(self, fn, expected_partition):
        traced = symbolic_trace(fn)

        supported_ops = MockOperatorSupport()
        partitioner = CapabilityBasedPartitioner(traced,
                                                 supported_ops,
                                                 allows_single_node_partition=True)
        partitions = partitioner.propose_partitions()
        partitions_name = [[node.name for node in partition.nodes] for partition in partitions]
        assert len(partitions_name) == len(expected_partition)
        for i in range(len(partitions_name)):
            assert set(partitions_name[i]) == set(expected_partition[i])

        fused_graph = partitioner.fuse_partitions(partitions)

        a, b, c, d, e, f = torch.rand(4), torch.rand(4), torch.rand(4), torch.rand(4), torch.rand(4), torch.rand(4)

        expected = fn(a, b, c, d, e, f)
        result = fused_graph(a, b, c, d, e, f)
        torch.testing.assert_close(expected, result)

    @parametrize("partition", [
        [['add', 'add_1'], ['add_5', 'add_6']],
        [['add', 'add_1', 'add_2']],  # vertical fusion
        [['add_2', 'add_3']],  # horizontal fusion
        [['add_3', 'add_4']],
        [['add_6', 'add_5']],  # arbitrary node order
        [['add_4', 'add_1', 'add_3', 'add_2']],  # arbitrary node order
        [['add_5', 'add_6'], ['add_1', 'add_2', 'add_3', 'add_4']],  # arbitrary partition order
        [['add_5', 'linear2']],  # includes call_function + call_module nodes
        [['add_6', 'relu']],  # includes call_function + call_method nodes
        [['param', 'add_2']],  # includes get_attr + call_function nodes
        [['param', 'add_1', 'linear']],  # includes get_attr + call_function + call_module nodes
        [["add", "linear", "add_1", "param", "add_2", "add_3", "add_4", "linear2", "add_5", "add_6", "relu"]],  # full graph
    ])
    def test_fuser_util(self, partition):
        m = TestModule()
        gm = symbolic_trace(m)

        nodes_by_name = {node.name: node for node in gm.graph.nodes}

        partitions = []
        for node_names in partition:
            partitions.append([nodes_by_name[name] for name in node_names])

        fused_graph = fuse_by_partitions(gm, partitions)

        a, b, c = torch.rand(4), torch.rand(4), torch.rand(4)

        expected = m(a, b, c)
        result = fused_graph(a, b, c)
        torch.testing.assert_close(expected, result)
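
    # A partition is rejected when fusing it would create a cyclic dependency.
    # For example, in TestModule we have add_4 -> linear2 -> add_5; fusing
    # {add_4, add_5} while leaving linear2 outside would make the fused
    # submodule both produce linear2's input and consume its output.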
    @parametrize("partition", [
        [['add', 'add_1'], ['add_1', 'add_5', 'add_6']],  # add_1 exists in multiple partitions
        [['add', 'add_1', 'add_3']],  # invalid partition: circular dependency
        [['add_4', 'add_5']],  # invalid partition: circular dependency
        [['relu', 'add_5']],  # invalid partition: circular dependency
    ])
    def test_fuser_util_xfail(self, partition):
        m = TestModule()
        gm = symbolic_trace(m)

        nodes_by_name = {node.name: node for node in gm.graph.nodes}

        partitions = []
        for node_names in partition:
            partitions.append([nodes_by_name[name] for name in node_names])

        with self.assertRaises(Exception):
            fuse_by_partitions(gm, partitions)

    def test_fuser_pass_deep_model(self):
        m = TestDeepModule()
        traced = symbolic_trace(m)

        supported_ops = MockOperatorSupport()
        partitioner = CapabilityBasedPartitioner(traced,
                                                 supported_ops,
                                                 allows_single_node_partition=True)
        # passes if proposing partitions on a very deep graph does not hit
        # Python's recursion limit
        partitions = partitioner.propose_partitions()

# Parameters for a single SubgraphMatcher run and its expected match count.
@dataclass
class TestCase:
    match_output: bool
    match_placeholder: bool
    num_matches: int

    remove_overlapping_matches: bool = True

# Each pattern class below pairs a target graph (forward) with a pattern to
# search for in it, plus the TestCase entries for every matcher configuration
# exercised.
class SingleNodePattern:
    @staticmethod
    def forward(x):
        val = torch.neg(x)
        return torch.add(val, val)

    @staticmethod
    def pattern(a):
        return torch.neg(a)

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 0),
        TestCase(False, True, 1),
        TestCase(True, True, 0)
    ]

class SimplePattern:
    @staticmethod
    def forward(x, w1, w2):
        m1 = torch.cat([w1, w2]).sum()
        m2 = torch.cat([w2, w1]).sum()
        m3 = torch.cat([m1, m2]).sum()
        return x + torch.max(m1) + torch.max(m2) + m3

    @staticmethod
    def pattern(a, b):
        return torch.cat([a, b]).sum()

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 3),
        TestCase(True, False, 0),
        TestCase(False, True, 2),
        TestCase(True, True, 0)
    ]

class SimpleFullGraphMatching:
    @staticmethod
    def forward(x):
        a = torch.neg(x)
        return torch.add(a, a)

    @staticmethod
    def pattern(x):
        a = torch.neg(x)
        return torch.add(a, a)

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 1),
        TestCase(False, True, 1),
        TestCase(True, True, 1)
    ]

class DiamondShapePatternTestCase:
    @staticmethod
    def forward(x):
        a = torch.neg(x)

        a = a.relu()
        left = a.sigmoid()
        right = a.relu()
        out = left + right

        return out

    @staticmethod
    def pattern(a):
        a = a.relu()
        left = a.sigmoid()
        right = a.relu()
        out = left + right
        return out

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 1),
        TestCase(False, True, 0),
        TestCase(True, True, 0)
    ]

class NonFullyContainedMatches:
    @staticmethod
    def forward(x, w1, w2, b1, b2):
        # fully contained matched subgraph
        m1 = torch.cat([w1, w2])
        m2 = torch.cat([x, b2])
        t0 = torch.addmm(b1, m1, m2.t())
        t0_sum = torch.sum(t0)  # use of t0 is not leaking

        # leaking matched subgraph, m3 is leaked
        m3 = torch.cat([w1, w2])
        m4 = torch.cat([x, b2])
        t1 = torch.addmm(b1, m3, m4.t())
        m3_sum = torch.sum(m3)

        return t0_sum, m3_sum

    @staticmethod
    def pattern(x, w1, w2, b1, b2):
        m1 = torch.cat([w1, w2])
        m2 = torch.cat([x, b2])
        return torch.addmm(b1, m1, m2.t())

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 0),
        TestCase(False, True, 1),  # leaked use of a placeholder is not leaking
    ]

class ChainRepeatedPattern:
    @staticmethod
    def forward(x):
        x = torch.sigmoid(x)
        x = torch.sigmoid(x)
        x = torch.sigmoid(x)
        return torch.sigmoid(x)

    @staticmethod
    def pattern(x):
        return torch.sigmoid(torch.sigmoid(x))

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 3, remove_overlapping_matches=False),
        TestCase(False, False, 2, remove_overlapping_matches=True),
        TestCase(True, False, 1),
        TestCase(False, True, 1),
        TestCase(True, True, 0)
    ]

class QuantizationModel:
    @staticmethod
    def forward(x):
        x += 3
        x = x.dequantize()
        x = torch.sigmoid(x)
        x = x.to(torch.float16)
        return x

    @staticmethod
    def pattern(x):
        x = x.dequantize()
        x = torch.sigmoid(x)
        x = x.to(torch.float16)
        return x

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 1),
        TestCase(False, True, 0),
        TestCase(True, True, 0)
    ]

class MultipleOutputsWithDependency:
    @staticmethod
    def forward(x):
        y = x.relu()
        z = y.sigmoid()
        return z, y

    @staticmethod
    def pattern(a):
        b = a.relu()
        c = b.sigmoid()
        return b, c  # outputs have data dependency

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 0),
        TestCase(False, True, 1),
        TestCase(True, True, 0)
    ]

class MultipleOutputsWithoutDependency:
    @staticmethod
    def forward(x):
        x = x + 1

        # target subgraph to match
        x = x.relu()
        z = x.sum()
        y = x.sigmoid()

        out = y.sigmoid() + z.sum()
        return out

    @staticmethod
    def pattern(a):
        a = a.relu()
        b = a.sigmoid()
        c = a.sum()
        return b, c

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 0),
        TestCase(False, True, 0),
        TestCase(True, True, 0)
    ]

class MultipleOutputsMultipleOverlappingMatches:
    @staticmethod
    def forward(x):
        x = x + 1

        # target subgraph to match
        x = x.relu()
        z = x.sum()
        z1 = x.sum()
        y = x.sigmoid()
        y1 = x.sigmoid()

        return z + z1 + y + y1

    @staticmethod
    def pattern(a):
        a = a.relu()
        b = a.sigmoid()
        c = a.sum()
        return a, b, c

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 4, remove_overlapping_matches=False),
        TestCase(False, False, 1, remove_overlapping_matches=True),
    ]

class MultipleOutputsMultipleNonOverlappingMatches:
    @staticmethod
    def forward(x):
        x = x + 1

        # target subgraph to match
        x = x.relu()
        z = x.sum()
        y = x.sigmoid()

        x = x.relu()
        z1 = x.sum()
        y1 = x.sigmoid()

        return z + z1 + y + y1

    @staticmethod
    def pattern(a):
        a = a.relu()
        b = a.sigmoid()
        c = a.sum()
        return b, c

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
    ]

class MultipleOutputsIdenticalAnchor:
    @staticmethod
    def forward(x):
        x = x + 1

        # target subgraph to match
        x = x.relu()
        y = x.sigmoid()
        y1 = x.sigmoid()

        return y, y1

    @staticmethod
    def pattern(a):
        a = a.relu()
        b = a.sigmoid()
        b1 = a.sigmoid()
        return b, b1

    test_cases = [
        # match_output, match_placeholder, num_matches
        # (False, False, 2),  # FIXME: currently still matches to 2, should fix to 1
        TestCase(True, False, 1),
        TestCase(False, True, 0),
    ]

class MultipleOutputsHorizontalPattern:
    @staticmethod
    def forward(x):
        x = x + 1

        # target subgraph to match
        y1 = x.relu()
        y2 = x.sigmoid()

        return y1, y2

    @staticmethod
    def pattern(a):
        b1 = a.relu()
        b2 = a.sigmoid()
        return b1, b2

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 1),
        TestCase(False, True, 0),
        TestCase(True, True, 0)
    ]

class MultiOutputWithWithInvalidMatches:
    @staticmethod
    def forward(x):
        res0 = torch.nn.functional.linear(x, torch.rand(3, 3))
        res1 = torch.sigmoid(res0)
        res2 = res0 * res1
        res3 = torch.sum(res2, dim=1)
        return res3

    @staticmethod
    def pattern(a, b, c):
        lin_res = torch.nn.functional.linear(a, b)
        mul_res = lin_res * c
        return lin_res, mul_res

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 0),
        TestCase(True, False, 0),
        TestCase(False, True, 0),
    ]
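
# QuantizationFp8Pattern registers its custom quantize/dequantize ops through
# torch.library in setup() (invoked by the test harness below) so that
# forward and pattern can reference torch.ops.fp8_quantization.*; tearDown()
# releases the library afterwards.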
class QuantizationFp8Pattern:
    @classmethod
    def setup(cls):
        cls.quantization = torch.library.Library("fp8_quantization", "DEF")  # noqa: TOR901
        cls.quantization.define("quantize_per_tensor_affine_fp8(Tensor self, int dtype, float scale) -> Tensor")
        cls.quantization.define("dequantize_per_tensor_affine_fp8(Tensor self, int dtype, float scale) -> Tensor")

    @classmethod
    def tearDown(cls):
        del cls.quantization

    @staticmethod
    def forward(self, arg0_1, arg1_1):
        qt = torch.ops.fp8_quantization
        _scale_0 = self._scale_0
        quantize_per_tensor_affine_fp8 = qt.quantize_per_tensor_affine_fp8(arg0_1, 0, _scale_0)
        dequantize_per_tensor_affine_fp8 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8, 0, _scale_0)
        _scale_1 = self._scale_0
        quantize_per_tensor_affine_fp8_1 = qt.quantize_per_tensor_affine_fp8(arg1_1, 0, _scale_1)
        dequantize_per_tensor_affine_fp8_1 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8_1, 0, _scale_1)
        add = torch.ops.aten.add.Tensor(dequantize_per_tensor_affine_fp8, dequantize_per_tensor_affine_fp8_1)
        _scale_2 = self._scale_0
        quantize_per_tensor_affine_fp8_2 = qt.quantize_per_tensor_affine_fp8(add, 0, _scale_2)
        dequantize_per_tensor_affine_fp8_2 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8_2, 0, _scale_2)
        return dequantize_per_tensor_affine_fp8_2

    @staticmethod
    def pattern(a, a_dtype, a_scale, b, b_dtype, b_scale, out_scale):
        qt = torch.ops.fp8_quantization
        a = qt.dequantize_per_tensor_affine_fp8(a, a_dtype, a_scale)
        b = qt.dequantize_per_tensor_affine_fp8(b, b_dtype, b_scale)
        output = torch.ops.aten.add.Tensor(a, b)
        output = qt.quantize_per_tensor_affine_fp8(output, a_dtype, out_scale)
        return output

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
    ]

class NoAnchorFound:
    # This test case covers patterns for which no matching anchor is found in
    # the target graph. An `anchor` is the starting point of pattern matching,
    # usually one of the boundary returning nodes.
    @staticmethod
    def forward(x):
        x = x + 1
        return x

    @staticmethod
    def pattern(a):
        b1 = a.relu()
        return b1

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 0),
        TestCase(True, False, 0),
        TestCase(False, True, 0),
        TestCase(True, True, 0)
    ]

@instantiate_parametrized_tests
class TestFXMatcherUtils(JitTestCase):

    @parametrize("test_model", [
        SingleNodePattern,
        SimplePattern,
        SimpleFullGraphMatching,
        DiamondShapePatternTestCase,
        NonFullyContainedMatches,
        ChainRepeatedPattern,
        QuantizationModel,
        MultipleOutputsWithDependency,
        MultipleOutputsWithoutDependency,
        MultipleOutputsMultipleOverlappingMatches,
        MultipleOutputsMultipleNonOverlappingMatches,
        MultipleOutputsIdenticalAnchor,
        MultipleOutputsHorizontalPattern,
        MultiOutputWithWithInvalidMatches,
        QuantizationFp8Pattern,
        NoAnchorFound,
    ])
    def test_subgraph_matcher(self, test_model):
        setup = getattr(test_model, "setup", None)
        if callable(setup):
            setup()

        traced = symbolic_trace(test_model.forward)
        pattern_traced = symbolic_trace(test_model.pattern)

        for test_case in test_model.test_cases:
            matcher = SubgraphMatcher(pattern_traced.graph,
                                      match_output=test_case.match_output,
                                      match_placeholder=test_case.match_placeholder,
                                      remove_overlapping_matches=test_case.remove_overlapping_matches)
            matches = matcher.match(traced.graph)

            assert len(matches) == test_case.num_matches

            for match in matches:
                for node in pattern_traced.graph.nodes:
                    if not test_case.match_placeholder and node.op == "placeholder":
                        continue
                    if not test_case.match_output and node.op == "output":
                        continue
                    assert node in match.nodes_map

        tearDown = getattr(test_model, "tearDown", None)
        if callable(tearDown):
            tearDown()

if __name__ == "__main__":
    run_tests()