# Owner(s): ["module: inductor"]
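"""Tests for Inductor's pre-grad and post-grad group/batch fusion passes.

Each module below builds a pattern (repeated linears, layer norms, or chains of
pointwise ops) that a specific fusion pass is expected to match. The tests compile
the module with torch.compile, compare outputs and gradients against eager mode,
and check the corresponding fusion counters.
"""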

import collections
import unittest
from typing import List

import torch
import torch._inductor
import torch._inductor.fx_passes.group_batch_fusion
from torch._dynamo.utils import counters, optimus_scuba_log
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.inductor_utils import HAS_CUDA


try:
    # importing this will register fbgemm lowerings for inductor
    import deeplearning.fbgemm.fbgemm_gpu.fb.inductor_lowerings  # noqa: F401

    has_fbgemm = True
except Exception:
    has_fbgemm = False

requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")


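# Gating module: for each of `size` inputs, computes gating_proj(x) * sigmoid(transform_proj(x))
# and concatenates the results. Used by test_gate_fusion_post_grad below to exercise the
# post-grad batch linear/sigmoid/mul fusions.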
class TestHighwaySelfGating(torch.nn.Module):
    def __init__(
        self,
        d_model: int,
        size: int,
        device="cuda",
    ) -> None:
        super().__init__()
        self.size = size
        self.device = device
        self.gating_proj = torch.nn.Linear(d_model, d_model).to(self.device)
        self.transform_proj = torch.nn.Linear(d_model, d_model).to(self.device)
        self.gating_func = torch.nn.Sigmoid().to(self.device)

        self.d_model = d_model

    def forward(
        self,
        inputs: List[torch.Tensor],
    ) -> torch.Tensor:
        results = []
        for i in range(self.size):
            x = inputs[i]
            gating_proj = self.gating_proj(x)
            transform_proj = self.transform_proj(x)
            x = gating_proj * self.gating_func(transform_proj)
            results.append(x)

        return torch.cat(results, dim=-1)


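# Three sequences of `seq_len` same-shape Linear layers with pointwise adds/subs between
# them. test_group_linear_fusion compiles this module and expects the fbgemm-backed
# group_linear pass to fuse each sequence of matmuls.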
class MyModule(torch.nn.Module):
    def __init__(self, z: int, has_bias: bool, device="cuda") -> None:
        super().__init__()
        self.z = z
        self.device = device
        self.seq_len = 10
        self.seq1 = [
            torch.nn.Linear(z, z, has_bias).to(self.device) for _ in range(self.seq_len)
        ]
        self.seq2 = [
            torch.nn.Linear(z, z, has_bias).to(self.device) for _ in range(self.seq_len)
        ]
        self.seq3 = [
            torch.nn.Linear(z, z, has_bias).to(self.device) for _ in range(self.seq_len)
        ]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1 = [x + 0.1 * i for i in range(self.seq_len)]
        x2 = [self.seq1[i](x1[i]) for i in range(self.seq_len)]
        x3 = [x2[i] - 0.1 * i for i in range(self.seq_len)]
        x4 = [x1[i] for i in range(3)] + [x3[i] for i in range(3, self.seq_len)]
        x5 = [self.seq2[i](x4[i]) for i in range(self.seq_len)]
        x6 = [x5[i] + 0.1 * (self.seq_len - i) for i in range(self.seq_len)]
        x7 = (
            [x1[i] for i in range(4)]
            + [x3[i] for i in range(6, 8)]
            + [x6[i] for i in range(4)]
        )
        x8 = [self.seq3[i](x7[i]) for i in range(self.seq_len)]
        x9 = torch.cat(x8, dim=1)
        return x9


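# Linear layers of differing input widths feeding BatchNorm1d and pointwise ops; used by
# test_group_linear_fusion_different_shapes to check group_linear fusion across mixed shapes.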
class MyModule2(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear0 = torch.nn.Linear(6, 8)
        self.linear1 = torch.nn.Linear(8, 8)
        self.linear2 = torch.nn.Linear(10, 8)
        self.linear3 = torch.nn.Linear(6, 8)
        self.linear4 = torch.nn.Linear(8, 8)
        self.linear5 = torch.nn.Linear(10, 8)
        self.bn0 = torch.nn.BatchNorm1d(8)
        self.bn1 = torch.nn.BatchNorm1d(8)
        self.bn2 = torch.nn.BatchNorm1d(8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        t = torch.split(x, [6, 8, 10], dim=1)
        a0 = self.bn0(self.linear0(t[0] + 0.1))
        a1 = self.bn1(self.linear1(t[1] + 0.2))
        a2 = self.bn2(self.linear2(t[2] + 0.3))
        a3 = self.linear3(torch.sin(t[0]))
        a4 = self.linear4(torch.cos(t[1]))
        a5 = self.linear5(torch.sin(t[2] * 0.5))

        b = torch.cat([a0, a1, a2, a3, a4, a5])
        return torch.sigmoid(b)


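# Splits the input and applies a per-chunk layer_norm with optional weight/bias; used by
# test_batch_layer_norm_fusion to exercise the pre-grad batch_layernorm pass.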
class MyModule3(torch.nn.Module):
    def __init__(self, device, has_weight=True, has_bias=True):
        super().__init__()
        self.device = device
        self.scale0 = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.randn(10)) for _ in range(5)]
        ).to(self.device)
        self.bias0 = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.randn(10)) for _ in range(5)]
        ).to(self.device)
        self.scale1 = (
            torch.nn.ParameterList(
                [torch.nn.Parameter(torch.randn(5, 10)) for _ in range(5)]
            ).to(self.device)
            if has_weight
            else [None for _ in range(5)]
        )
        self.bias1 = (
            torch.nn.ParameterList(
                [torch.nn.Parameter(torch.randn(5, 10)) for _ in range(5)]
            ).to(self.device)
            if has_bias
            else [None for _ in range(5)]
        )

    def forward(self, x):
        l1_out = torch.split(x.to(self.device), 10, dim=2)
        post_l1 = [
            torch.nn.functional.layer_norm(
                l1_out[i], (10,), weight=self.scale0[i], bias=self.bias0[i]
            )
            for i in range(len(l1_out))
        ]
        l1_out = torch.cat(post_l1, dim=2)

        l2_out = torch.split(l1_out, 10, dim=2)
        post_l2 = [
            torch.nn.functional.layer_norm(
                l2_out[i], (5, 10), weight=self.scale1[i], bias=self.bias1[i]
            )
            for i in range(len(l2_out))
        ]

        return torch.cat(post_l2, dim=2)


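# Many functional.linear calls sharing the same left-hand-side input but with different
# weight shapes; used by test_batch_linear_lhs_fusion to exercise the batch_linear_lhs pass.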
class MyModule4(torch.nn.Module):
    def __init__(self, z, device, has_bias):
        super().__init__()
        self.z = z
        self.device = device
        self.has_bias = has_bias
        self.seq_len = 10
        self.weights1 = [
            torch.nn.Parameter(torch.randn(z - i % 5, z)).to(self.device)
            for i in range(self.seq_len)
        ]
        self.weights2 = [
            torch.nn.Parameter(torch.randn(z - i % 5, z)).to(self.device)
            for i in range(self.seq_len)
        ]

        if has_bias:
            self.biases1 = [
                torch.nn.Parameter(torch.randn(z - i % 5)).to(self.device)
                for i in range(self.seq_len)
            ]
            self.biases2 = [
                torch.nn.Parameter(torch.randn(z - i % 5)).to(self.device)
                for i in range(self.seq_len)
            ]

    def forward(self, x):
        x = x + 1.2
        x1 = [
            torch.nn.functional.linear(
                x, self.weights1[i], self.biases1[i] if self.has_bias else None
            )
            for i in range(self.seq_len)
        ]
        x2 = torch.cat(x1, dim=1)
        x3 = torch.split(x2, 10, dim=1)
        x4 = torch.cat(x3)
        x5 = [
            torch.nn.functional.linear(
                x4, self.weights2[i], self.biases2[i] if self.has_bias else None
            )
            for i in range(self.seq_len)
        ]
        x6 = torch.cat(x5, dim=1)
        return torch.sigmoid(x6)


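# Splits the input and applies one functional.linear per chunk; used by
# test_batch_linear_pre_grad_fusion to exercise the pre-grad batch_linear pass.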
class MyModule5(torch.nn.Module):
    def __init__(self, device, has_bias=True):
        super().__init__()
        self.device = device

        self.weights = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.randn(50, 100)).to(self.device) for _ in range(5)]
        )

        self.biases = (
            ([torch.nn.Parameter(torch.randn(50)).to(self.device) for _ in range(5)])
            if has_bias
            else [None for _ in range(5)]
        )

    def forward(self, x):
        l1_out = torch.split(x.to(self.device), 100, dim=1)
        l1_linear = [
            torch.nn.functional.linear(l1_out[i], self.weights[i], self.biases[i])
            for i in range(len(l1_out))
        ]
        l1_out = torch.cat(l1_linear, dim=1)
        return torch.sin(l1_out)


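# Chains of tanh/sigmoid/relu/add/mul/sub/div over split tensors; used by
# test_pointwise_op_fusion to exercise the pre-grad batch_tanh/batch_relu/batch_sigmoid
# passes together with the post-grad batch_aten_add/mul/sub/div passes.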
class TestPoitwiseOps(torch.nn.Module):
    def __init__(self, device, has_bias=True):
        super().__init__()
        self.device = device

    def forward(self, x):
        inputs = torch.split(x.to(self.device), 500, dim=1)
        x_split = torch.split(inputs[0].to(self.device), 50, dim=1)
        y_split = torch.split(inputs[1].to(self.device), 50, dim=1)
        tanh_1 = [torch.tanh(x_split[i]) for i in range(len(x_split))]
        tanh_2 = [torch.tanh(y_split[i]) for i in range(len(y_split))]
        sigmoid_1 = [torch.sigmoid(tanh_1[i]) for i in range(len(tanh_1))]
        sigmoid_2 = [torch.sigmoid(tanh_2[i]) for i in range(len(tanh_2))]
        relu_1 = [torch.nn.functional.relu(sigmoid_1[i]) for i in range(len(sigmoid_1))]
        relu_2 = [torch.nn.functional.relu(sigmoid_2[i]) for i in range(len(sigmoid_2))]
        add = [torch.add(relu_1[i], relu_2[i]) for i in range(len(relu_1))]
        mul = [torch.mul(add[i], add[i]) for i in range(len(add))]
        sub = [torch.sub(mul[i], mul[i]) for i in range(len(mul))]
        div = [torch.div(sub[i], sub[i]) for i in range(len(sub))]
        return torch.cat(div, dim=1)


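# The same pointwise pattern expressed directly with aten ops; used by
# test_pointwise_op_fusion_post_grad to exercise the post-grad batch_aten_* passes.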
class TestPoitwiseOpsPostGrad(torch.nn.Module):
    def __init__(self, device):
        super().__init__()
        self.device = device

    def forward(self, x):
        inputs = torch.ops.aten.split(x.to(self.device), 500, dim=1)
        x_split = torch.ops.aten.split(inputs[0].to(self.device), 50, dim=1)
        y_split = torch.ops.aten.split(inputs[1].to(self.device), 50, dim=1)
        tanh_1 = [torch.ops.aten.tanh(x_split[i]) for i in range(len(x_split))]
        tanh_2 = [torch.ops.aten.tanh(y_split[i]) for i in range(len(y_split))]
        sigmoid_1 = [torch.ops.aten.sigmoid(tanh_1[i]) for i in range(len(tanh_1))]
        sigmoid_2 = [torch.ops.aten.sigmoid(tanh_2[i]) for i in range(len(tanh_2))]
        relu_1 = [torch.ops.aten.relu(sigmoid_1[i]) for i in range(len(sigmoid_1))]
        relu_2 = [torch.ops.aten.relu(sigmoid_2[i]) for i in range(len(sigmoid_2))]
        add = [torch.ops.aten.add(relu_1[i], relu_2[i]) for i in range(len(relu_1))]
        return torch.cat(add, dim=1)


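# The patches below enable the fusion passes under test: pre_grad_fusion_options and
# post_grad_fusion_options map a pass name to its per-pass option dict, and an empty dict
# appears to enable a pass with its defaults. A minimal sketch of enabling the same passes
# outside of a test (assuming the config surface patched here):
#
#     torch._inductor.config.pre_grad_fusion_options = {"batch_linear": {}}
#     torch._inductor.config.post_grad_fusion_options = {"group_linear": {"require_fbgemm": True}}
#     compiled = torch.compile(module)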
@requires_cuda
@torch._inductor.config.patch(
    pre_grad_fusion_options={
        "batch_linear": {},
        "batch_linear_lhs": {},
        "batch_layernorm": {},
        "batch_tanh": {},
        "batch_relu": {},
        "batch_sigmoid": {},
    },
    post_grad_fusion_options={
        "batch_aten_add": {},
        "batch_aten_mul": {},
        "batch_aten_sub": {},
        "batch_aten_div": {},
        "group_linear": {"require_fbgemm": True},
    },
)
class TestGroupBatchFusion(TestCase):
    def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3):
        if len(set(ref_dict.keys())) != len(set(res_dict.keys())):
            return False
        for key1 in ref_dict.keys():
            key2 = "_orig_mod." + key1
            assert key2 in res_dict, f"{key1} does not exist in traced module"
            if not torch.allclose(ref_dict[key1], res_dict[key2], rtol=rtol, atol=atol):
                return False
        return True

    def compare_pred(self, module, traced, input, rtol=1e-3, atol=1e-3):
        ref = module(*input)
        res = traced(*input)
        self.assertEqual(ref, res, rtol=rtol, atol=atol)

    def compare_parameters(self, module, traced, rtol=1e-3, atol=1e-3):
        ref_params = dict(module.named_parameters())
        res_params = dict(traced.named_parameters())
        self.assertTrue(self.compare_dict_tensors(ref_params, res_params, rtol, atol))

    def compare_gradients(self, module, traced, rtol=1e-3, atol=1e-3):
        ref_grad = {key: param.grad for key, param in module.named_parameters()}
        res_grad = {key: param.grad for key, param in traced.named_parameters()}
        self.assertTrue(
            self.compare_dict_tensors(ref_grad, res_grad, rtol=rtol, atol=atol)
        )

    @unittest.skipIf(not has_fbgemm, "requires fbgemm")
    def test_group_linear_fusion(self):
        z = 10
        for has_bias in [True, False]:
            counters.clear()
            module = MyModule(z, has_bias).to("cuda")
            input = [torch.randn(z, z, device="cuda")]
            traced = torch.compile(module)
            ref = module(*input)
            res = traced(*input)
            self.compare_pred(module, traced, input)
            self.assertEqual(
                counters["inductor"]["group_linear"],
                2,
            )
            self.assertNotIn("group_batch_fusion_pre_grad", optimus_scuba_log)
            ref.sum().backward()
            res.sum().backward()
            self.compare_parameters(module, traced)
            self.compare_gradients(module, traced)
            self.assertEqual(
                counters["inductor"]["group_linear"],
                4,
            )
            self.assertEqual(
                counters["inductor"]["batch_aten_add"],
                3,
            )
            self.assertIn("GroupLinearFusion", optimus_scuba_log)
            counters.clear()

    @unittest.skipIf(not has_fbgemm, "requires fbgemm")
    def test_group_linear_fusion_different_shapes(self):
        counters.clear()
        module = MyModule2().eval().to("cuda")
        input = [torch.rand(4, 24, device="cuda")]
        traced = torch.compile(module)
        ref = module(*input)
        res = traced(*input)
        self.compare_pred(module, traced, input)
        self.assertEqual(
            counters["inductor"]["group_linear"],
            1,
        )
        self.assertEqual(
            counters["inductor"]["batch_fusion"],
            0,
        )
        ref.sum().backward()
        res.sum().backward()
        self.compare_parameters(module, traced)
        self.compare_gradients(module, traced)
        self.assertEqual(
            counters["inductor"]["group_linear"],
            2,
        )
        self.assertEqual(
            counters["inductor"]["batch_aten_mul"],
            1,
        )
        counters.clear()

    def test_batch_layer_norm_fusion(self):
        for has_weight in [True, False]:
            for has_bias in [True, False]:
                counters.clear()
                module = MyModule3("cuda", has_weight, has_bias).to("cuda")
                input = [torch.randn(2, 5, 50, device="cuda")]
                traced = torch.compile(module)
                ref = module(*input)
                res = traced(*input)
                self.compare_pred(module, traced, input)
                self.assertEqual(counters["inductor"]["batch_layernorm"], 2)
                ref.sum().backward()
                res.sum().backward()
                self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8)
                self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8)
                counters.clear()

    def test_batch_linear_lhs_fusion(self):
        z = 10
        for has_bias in [True, False]:
            counters.clear()
            module = MyModule4(z, "cuda", has_bias)
            input = [torch.randn(20, z, device="cuda")]
            traced = torch.compile(module)
            ref = module(*input)
            res = traced(*input)
            self.compare_pred(module, traced, input)
            self.assertEqual(counters["inductor"]["batch_linear_lhs"], 2)
            ref.sum().backward()
            res.sum().backward()
            self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8)
            self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8)
            counters.clear()

    def test_batch_linear_pre_grad_fusion(self):
        for has_bias in [True, False]:
            counters.clear()
            module = MyModule5("cuda", has_bias)
            input = [torch.randn(50, 500, device="cuda")]
            traced = torch.compile(module)
            ref = module(*input)
            res = traced(*input)
            self.compare_pred(module, traced, input)
            self.assertEqual(counters["inductor"]["batch_linear"], 1)
            ref.sum().backward()
            res.sum().backward()
            self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8)
            self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8)
            counters.clear()

    def test_pointwise_op_fusion(self):
        counters.clear()
        module = TestPoitwiseOps("cuda")
        input = [torch.randn(50, 1000, requires_grad=True, device="cuda")]
        traced = torch.compile(module)
        ref = module(*input)
        res = traced(*input)
        self.compare_pred(module, traced, input)
        self.assertEqual(counters["inductor"]["batch_tanh"], 1)
        self.assertEqual(counters["inductor"]["batch_relu"], 1)
        self.assertEqual(counters["inductor"]["batch_sigmoid"], 1)
        self.assertEqual(counters["inductor"]["batch_aten_add"], 1)
        self.assertEqual(counters["inductor"]["batch_aten_mul"], 1)
        self.assertEqual(counters["inductor"]["batch_aten_sub"], 1)
        self.assertEqual(counters["inductor"]["batch_aten_div"], 1)
        ref.sum().backward()
        res.sum().backward()
        self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8)
        self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8)
        counters.clear()

    @requires_cuda
    @torch._inductor.config.patch(
        pre_grad_fusion_options={},
        post_grad_fusion_options={
            "batch_aten_relu": {},
            "batch_aten_sigmoid": {},
            "batch_aten_tanh": {},
            "unbind_stack_aten_pass": {},
        },
    )
    def test_pointwise_op_fusion_post_grad(self):
        counters.clear()
        module = TestPoitwiseOpsPostGrad("cuda")
        input = [torch.randn(50, 1000, requires_grad=True, device="cuda")]
        traced = torch.compile(module)
        ref = module(*input)
        res = traced(*input)
        self.compare_pred(module, traced, input)
        self.assertEqual(counters["inductor"]["batch_aten_tanh"], 1)
        self.assertEqual(counters["inductor"]["batch_aten_relu"], 1)
        self.assertEqual(counters["inductor"]["batch_aten_sigmoid"], 1)
        self.assertEqual(counters["inductor"]["unbind_stack_aten_pass"], 2)
        ref.sum().backward()
        res.sum().backward()
        self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8)
        self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8)
        counters.clear()

    @requires_cuda
    @torch._inductor.config.patch(
        pre_grad_fusion_options={},
        post_grad_fusion_options={
            "batch_linear_post_grad": {
                "shape_broadcast_batch_linear": True,
                "fuse_nodes_with_same_users": True,
            },
            "batch_aten_mul": {"fuse_nodes_with_same_parent": False},
            "batch_aten_sigmoid": {"fuse_nodes_with_same_parent": True},
            "batch_aten_add": {"fuse_nodes_with_same_parent": True},
            "normalization_aten_pass": {},
            "unbind_stack_aten_pass": {},
        },
    )
    def test_gate_fusion_post_grad(self):
        counters.clear()
        size = 20
        module = TestHighwaySelfGating(d_model=10, size=size)
        input = [
            [
                torch.randn(10, 10, requires_grad=True, device="cuda")
                for i in range(size)
            ]
        ]
        traced = torch.compile(module)
        ref = module(*input)
        res = traced(*input)
        self.compare_pred(module, traced, input)
        self.assertEqual(counters["inductor"]["batch_linear_post_grad"], 2)
        self.assertEqual(counters["inductor"]["batch_aten_sigmoid"], 1)
        self.assertEqual(counters["inductor"]["batch_aten_mul"], 1)
        self.assertEqual(counters["inductor"]["batch_aten_add"], 2)
        self.assertEqual(counters["inductor"]["normalization_aten_pass"], 1)
        self.assertEqual(counters["inductor"]["unbind_stack_aten_pass"], 5)
        ref.sum().backward()
        res.sum().backward()
        self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8)
        self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8)
        counters.clear()


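# Ten independent 10x10 Linear layers whose outputs are summed; a candidate for the
# post-grad batch_linear_post_grad (batched-mm) fusion checked by
# TestPostGradBatchLinearFusion below.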
class TestBMMFusionModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.my_modules = torch.nn.ModuleList()
        for _ in range(10):
            self.my_modules.append(torch.nn.Linear(10, 10))

    def forward(self, inputs):
        output = None
        for linear, input in zip(self.my_modules, inputs):
            if output is None:
                output = linear(input)
            else:
                output += linear(input)
        return output


@requires_cuda
@torch._inductor.config.patch(
    post_grad_fusion_options={"batch_linear_post_grad": {"require_fbgemm": False}}
)
class TestPostGradBatchLinearFusion(TestCase):
    def test_batch_linear_post_grad_fusion(self):
        pt1_module = TestBMMFusionModule().cuda()
        inputs = []
        for _ in range(10):
            inputs.append(torch.randn(10, 10).cuda())
        eager_output = pt1_module(inputs)
        pt2_module = torch.compile(pt1_module)
        pt2_output = pt2_module(inputs)
        self.assertTrue(torch.allclose(eager_output, pt2_output))
        self.assertEqual(
            counters["inductor"]["batch_linear_post_grad"],
            2,
        )
        self.assertIn("PostGradBatchLinearFusion", optimus_scuba_log)


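# Unit tests for find_independent_subset_greedy, which greedily partitions candidate nodes
# into subsets of mutually independent nodes (no dependency path between members), honoring
# min_fuse_set_size / max_fuse_set_size. A minimal sketch of the call shape used by verify():
#
#     subsets = torch._inductor.fx_passes.group_batch_fusion.find_independent_subset_greedy(
#         nodes, {"min_fuse_set_size": 0, "max_fuse_set_size": 100}
#     )  # lazily yields lists of fx.Node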
class TestFindIndependentSubsetGreedy(TestCase):
    # Helper function to build a Graph from a data description.
    def build_graph(self, desc):
        # desc: {
        #   "n1": ["n2", "n3"],
        #   "n2": ["n3"],
        #   "n3": [],
        # }
        #
        g = torch.fx.Graph()
        lookup = {}
        desc = collections.deque((k, v) for k, v in desc.items())
        unsatisfied = 0
        while desc:
            unsatisfied += 1
            assert unsatisfied <= len(desc)  # cycle or bad input?
            name, v = desc.popleft()
            args = tuple(lookup.get(n, None) for n in v)
            if None in args:
                desc.append((name, v))
                continue
            node = g.create_node("placeholder", "target", name=name, args=args)
            lookup[name] = node
            unsatisfied = 0
        return g, lookup

    def verify(self, tree, subnodes, min_fuse, max_fuse, expected):
        g, lookup = self.build_graph(tree)
        subnodes = [lookup[n] for n in subnodes]
        expected = [[lookup[n] for n in sub] for sub in expected]
        opts = {
            "min_fuse_set_size": min_fuse,
            "max_fuse_set_size": max_fuse,
        }
        result = list(
            torch._inductor.fx_passes.group_batch_fusion.find_independent_subset_greedy(
                subnodes, opts
            )
        )
        self.assertEqual(expected, result)

    def test_find_independent_subset_greedy(self):
        # First some randomly generated tests.
        self.verify({"n0": (), "n1": ()}, ["n0"], 0, 100, [["n0"]])
        self.verify(
            {"n0": (), "n1": (), "n2": ("n0",)}, ["n1", "n2"], 0, 100, [["n1", "n2"]]
        )
        self.verify(
            {
                "n0": (),
                "n1": (),
                "n2": ("n0",),
                "n3": (),
                "n4": ("n0", "n1", "n2"),
                "n5": ("n0", "n2", "n4"),
                "n6": ("n3",),
                "n7": ("n4", "n5", "n6", "n1", "n3"),
                "n8": ("n7", "n1", "n3", "n5", "n0"),
                "n9": ("n3", "n4", "n8", "n6", "n5", "n2", "n0", "n7"),
                "n10": ("n0",),
                "n11": ("n4", "n0", "n2", "n3", "n1", "n9"),
                "n12": ("n2", "n3", "n10", "n6", "n9"),
            },
            ["n10", "n5", "n3", "n4", "n9"],
            0,
            100,
            [["n10", "n5", "n3"], ["n4"], ["n9"]],
        )
        self.verify({"n0": (), "n1": (), "n2": ("n0",)}, ["n2"], 0, 100, [["n2"]])
        self.verify(
            {
                "n0": (),
                "n1": (),
                "n2": (),
                "n3": (),
                "n4": ("n3", "n1", "n0"),
                "n5": ("n1", "n2", "n4", "n0"),
                "n6": ("n0", "n3", "n2"),
                "n7": ("n6", "n1", "n5", "n4", "n3", "n0"),
                "n8": ("n2", "n7", "n3"),
                "n9": ("n3", "n5", "n6", "n7", "n2", "n1"),
                "n10": ("n8", "n0", "n2", "n4", "n6", "n3"),
                "n11": ("n6", "n5", "n8", "n1", "n3", "n10", "n2"),
                "n12": ("n7", "n4"),
            },
            ["n7"],
            0,
            100,
            [["n7"]],
        )
        self.verify(
            {
                "n0": (),
                "n1": (),
                "n2": (),
                "n3": ("n1", "n2"),
                "n4": ("n1",),
                "n5": (),
                "n6": ("n5",),
                "n7": ("n1", "n6", "n5", "n2", "n3", "n0"),
                "n8": ("n5", "n7", "n2", "n6"),
                "n9": ("n1",),
                "n10": ("n9",),
                "n11": ("n3", "n4", "n0", "n2"),
                "n12": ("n8", "n9", "n5", "n1"),
                "n13": ("n11", "n4", "n12", "n1", "n9", "n3", "n0"),
            },
            ["n9", "n2", "n8", "n10", "n5", "n6", "n13", "n7", "n3", "n0", "n4"],
            0,
            100,
            [
                ["n9", "n2", "n5", "n0", "n4"],
                ["n8", "n10"],
                ["n6", "n3"],
                ["n13"],
                ["n7"],
            ],
        )
        self.verify({"n0": ()}, ["n0"], 0, 100, [["n0"]])
        self.verify(
            {
                "n0": (),
                "n1": (),
                "n2": (),
                "n3": (),
                "n4": ("n1", "n2"),
                "n5": ("n0", "n4", "n1"),
                "n6": ("n1", "n5"),
                "n7": (),
                "n8": ("n7", "n1", "n3", "n5", "n6"),
                "n9": ("n2", "n1", "n8", "n0", "n4", "n7", "n6", "n5"),
                "n10": ("n4", "n7", "n2", "n3", "n8"),
                "n11": (),
                "n12": ("n9", "n7", "n5", "n11", "n8"),
                "n13": (
                    "n5",
                    "n6",
                    "n12",
                    "n3",
                    "n9",
                    "n8",
                    "n4",
                    "n11",
                    "n2",
                    "n10",
                    "n1",
                ),
                "n14": ("n7", "n3", "n12", "n10", "n2", "n0", "n4", "n5"),
                "n15": ("n9", "n5", "n1", "n13", "n8", "n10", "n12", "n7", "n11", "n3"),
                "n16": (
                    "n2",
                    "n4",
                    "n15",
                    "n5",
                    "n0",
                    "n6",
                    "n3",
                    "n8",
                    "n14",
                    "n12",
                    "n9",
                    "n10",
                    "n7",
                    "n13",
                ),
            },
            ["n0", "n3", "n2", "n11", "n1", "n6", "n12", "n5", "n4", "n15", "n8"],
            0,
            100,
            [
                ["n0", "n3", "n2", "n11", "n1"],
                ["n6"],
                ["n12"],
                ["n5"],
                ["n4"],
                ["n15"],
                ["n8"],
            ],
        )
        self.verify(
            {
                "n0": (),
                "n1": (),
                "n2": (),
                "n3": ("n2", "n1"),
                "n4": ("n2", "n3", "n1"),
                "n5": ("n3", "n1"),
                "n6": ("n1",),
                "n7": ("n5", "n4"),
                "n8": ("n6", "n2"),
            },
            ["n4", "n3", "n1", "n8", "n5", "n6", "n2"],
            0,
            100,
            [["n4", "n8", "n5"], ["n3", "n6"], ["n1", "n2"]],
        )
        self.verify(
            {
                "n0": (),
                "n1": (),
                "n2": (),
                "n3": ("n1", "n0"),
                "n4": ("n0",),
                "n5": ("n1", "n4"),
                "n6": ("n2", "n1", "n4"),
                "n7": ("n0", "n3"),
                "n8": ("n5", "n0", "n6", "n1", "n4", "n2", "n3"),
                "n9": ("n1", "n4", "n8", "n7", "n5"),
                "n10": ("n9", "n8", "n0", "n2", "n7", "n1", "n3", "n5"),
                "n11": ("n9", "n2", "n6", "n0", "n3"),
                "n12": ("n1", "n4", "n7", "n10", "n5", "n2", "n11", "n6"),
                "n13": ("n9", "n2", "n3", "n0", "n7", "n5", "n10", "n11"),
                "n14": (
                    "n8",
                    "n0",
                    "n3",
                    "n6",
                    "n10",
                    "n1",
                    "n5",
                    "n9",
                    "n12",
                    "n11",
                    "n4",
                ),
                "n15": (
                    "n3",
                    "n10",
                    "n0",
                    "n4",
                    "n9",
                    "n11",
                    "n2",
                    "n13",
                    "n12",
                    "n8",
                    "n5",
                    "n14",
                ),
                "n16": ("n6",),
                "n17": (
                    "n4",
                    "n3",
                    "n14",
                    "n8",
                    "n15",
                    "n16",
                    "n2",
                    "n5",
                    "n7",
                    "n12",
                    "n1",
                    "n0",
                    "n11",
                ),
            },
            ["n17", "n16", "n10", "n4", "n8", "n12", "n6", "n1"],
            0,
            100,
            [["n17"], ["n16", "n10"], ["n4", "n1"], ["n8"], ["n12"], ["n6"]],
        )
        self.verify(
            {
                "n0": (),
                "n1": (),
                "n2": ("n0",),
                "n3": ("n0", "n1"),
                "n4": ("n0",),
                "n5": ("n0",),
                "n6": ("n5", "n3", "n0", "n2"),
                "n7": (),
                "n8": ("n2", "n5", "n3", "n1", "n7", "n6", "n0"),
                "n9": ("n4",),
                "n10": ("n4", "n5", "n1", "n2", "n0", "n6", "n8", "n9", "n7"),
                "n11": ("n3", "n0", "n9", "n10", "n5", "n1", "n2", "n7", "n4", "n6"),
                "n12": ("n9", "n5"),
            },
            ["n8", "n3", "n1", "n12", "n2", "n5", "n11", "n4", "n10", "n6", "n0"],
            0,
            100,
            [
                ["n8", "n12"],
                ["n3", "n2", "n5", "n4"],
                ["n1", "n0"],
                ["n11"],
                ["n10"],
                ["n6"],
            ],
        )
        self.verify(
            {
                "n0": (),
                "n1": (),
                "n2": (),
                "n3": (),
                "n4": ("n2", "n3"),
                "n5": ("n1", "n3", "n2", "n4"),
                "n6": ("n5", "n4", "n1", "n3"),
                "n7": ("n5",),
                "n8": ("n5", "n4", "n1"),
                "n9": ("n2", "n3", "n1", "n5", "n7", "n0", "n8"),
                "n10": ("n5", "n3", "n1", "n7", "n8", "n9"),
                "n11": ("n1", "n4", "n2", "n0", "n8", "n9"),
                "n12": ("n4", "n3", "n9"),
                "n13": (
                    "n6",
                    "n10",
                    "n4",
                    "n8",
                    "n0",
                    "n11",
                    "n12",
                    "n7",
                    "n3",
                    "n2",
                    "n1",
                ),
                "n14": ("n4", "n13", "n2"),
                "n15": ("n11", "n7", "n6", "n10", "n14"),
                "n16": ("n15", "n3"),
                "n17": ("n10", "n2", "n7", "n0", "n5", "n6", "n9"),
                "n18": (
                    "n16",
                    "n8",
                    "n6",
                    "n9",
                    "n11",
                    "n12",
                    "n14",
                    "n5",
                    "n13",
                    "n4",
                    "n1",
                ),
            },
            [
                "n1",
                "n0",
                "n16",
                "n6",
                "n15",
                "n9",
                "n7",
                "n4",
                "n3",
                "n11",
                "n13",
                "n17",
                "n12",
                "n18",
            ],
            0,
            100,
            [
                ["n1", "n0", "n4"],
                ["n16", "n17"],
                ["n6", "n9"],
                ["n15"],
                ["n7"],
                ["n3"],
                ["n11", "n12"],
                ["n13"],
                ["n18"],
            ],
        )
        self.verify(
            {
                "n0": (),
                "n1": (),
                "n2": (),
                "n3": ("n2",),
                "n4": ("n1",),
                "n5": (),
                "n6": ("n1", "n4"),
                "n7": ("n5", "n1"),
                "n8": ("n6",),
                "n9": ("n6", "n1", "n2", "n0"),
                "n10": ("n0", "n7"),
                "n11": ("n0", "n4", "n3", "n5"),
                "n12": ("n9", "n8", "n7", "n4", "n0"),
            },
            ["n8", "n9", "n11", "n2", "n4", "n0", "n7", "n5", "n1"],
            0,
            100,
            [["n8", "n9", "n11", "n7"], ["n2", "n4", "n0", "n5"], ["n1"]],
        )
        self.verify(
            {"n0": (), "n1": (), "n2": (), "n3": ("n0",), "n4": ("n3",)},
            ["n1", "n2", "n4"],
            0,
            100,
            [["n1", "n2", "n4"]],
        )
        self.verify(
            {
                "n0": (),
                "n1": (),
                "n2": ("n1",),
                "n3": ("n2", "n1"),
                "n4": ("n3",),
                "n5": (),
                "n6": ("n1", "n5"),
                "n7": (),
                "n8": ("n4", "n5"),
                "n9": ("n0", "n3", "n6", "n4", "n5", "n8", "n7", "n1"),
                "n10": ("n3", "n0", "n6", "n9", "n7"),
                "n11": (),
                "n12": ("n1", "n8", "n3", "n6", "n7", "n0", "n10", "n5", "n9", "n11"),
                "n13": ("n9", "n11", "n4"),
                "n14": (),
                "n15": ("n6", "n12"),
                "n16": (
                    "n1",
                    "n7",
                    "n10",
                    "n3",
                    "n9",
                    "n0",
                    "n2",
                    "n5",
                    "n8",
                    "n13",
                    "n14",
                    "n15",
                    "n4",
                    "n6",
                ),
            },
            [
                "n11",
                "n16",
                "n5",
                "n12",
                "n7",
                "n2",
                "n0",
                "n6",
                "n3",
                "n9",
                "n8",
                "n15",
                "n14",
                "n4",
                "n13",
                "n1",
            ],
            0,
            100,
            [
                ["n11", "n5", "n7", "n2", "n0", "n14"],
                ["n16"],
                ["n12", "n13"],
                ["n6", "n3"],
                ["n9"],
                ["n8"],
                ["n15"],
                ["n4"],
                ["n1"],
            ],
        )
        self.verify({"n0": (), "n1": ()}, ["n1"], 0, 100, [["n1"]])
        self.verify(
            {
                "n0": (),
                "n1": (),
                "n2": ("n1",),
                "n3": (),
                "n4": ("n0", "n2", "n3"),
                "n5": ("n2", "n3"),
                "n6": ("n3",),
            },
            ["n6", "n2", "n3", "n1"],
            0,
            100,
            [["n6", "n2"], ["n3", "n1"]],
        )
        self.verify(
            {
                "n0": (),
                "n1": (),
                "n2": (),
                "n3": ("n2",),
                "n4": ("n0",),
                "n5": ("n1", "n2"),
                "n6": ("n2", "n3", "n1", "n0", "n5"),
                "n7": ("n6", "n2", "n0", "n4", "n5", "n1"),
                "n8": ("n4",),
                "n9": ("n4", "n6", "n7", "n1", "n2"),
            },
            ["n8", "n6", "n2", "n4", "n7", "n5", "n3", "n9"],
            0,
            100,
            [["n8", "n6"], ["n2", "n4"], ["n7"], ["n5", "n3"], ["n9"]],
        )
        self.verify(
            {
                "n0": (),
                "n1": (),
                "n2": (),
                "n3": ("n1", "n2"),
                "n4": ("n0",),
                "n5": ("n2", "n3", "n0", "n1"),
                "n6": ("n4", "n1"),
                "n7": ("n5",),
                "n8": ("n7", "n1", "n5", "n6", "n3", "n4", "n0"),
                "n9": ("n2", "n8"),
            },
            ["n1", "n7", "n4", "n2", "n0", "n8", "n3", "n5"],
            0,
            100,
            [["n1", "n4", "n2"], ["n7"], ["n0", "n3"], ["n8"], ["n5"]],
        )
        self.verify(
            {
                "n0": (),
                "n1": (),
                "n2": ("n0",),
                "n3": ("n1",),
                "n4": ("n2", "n1"),
                "n5": (),
                "n6": ("n0",),
                "n7": ("n6", "n3", "n2", "n1", "n0"),
                "n8": ("n0", "n2"),
                "n9": ("n6", "n5", "n8", "n4", "n0"),
                "n10": ("n1", "n7", "n5", "n8", "n6", "n2", "n4", "n9"),
            },
            ["n0"],
            0,
            100,
            [["n0"]],
        )

        # trivial test of min_fuse
        self.verify(
            {
                "n0": (),
                "n1": (),
                "n2": (),
                "n3": ("n1", "n2"),
                "n4": ("n1",),
                "n5": (),
                "n6": ("n5",),
                "n7": ("n1", "n6", "n5", "n2", "n3", "n0"),
                "n8": ("n5", "n7", "n2", "n6"),
                "n9": ("n1",),
                "n10": ("n9",),
                "n11": ("n3", "n4", "n0", "n2"),
                "n12": ("n8", "n9", "n5", "n1"),
                "n13": ("n11", "n4", "n12", "n1", "n9", "n3", "n0"),
            },
            ["n9", "n2", "n8", "n10", "n5", "n6", "n13", "n7", "n3", "n0", "n4"],
            2,
            10,
            [["n9", "n2", "n5", "n0", "n4"], ["n8", "n10"], ["n6", "n3"]],
        )

        # trivial test of max_fuse
        self.verify(
            {
                "n0": (),
                "n1": (),
                "n2": (),
                "n3": ("n1", "n2"),
                "n4": ("n1",),
                "n5": (),
                "n6": ("n5",),
                "n7": ("n1", "n6", "n5", "n2", "n3", "n0"),
                "n8": ("n5", "n7", "n2", "n6"),
                "n9": ("n1",),
                "n10": ("n9",),
                "n11": ("n3", "n4", "n0", "n2"),
                "n12": ("n8", "n9", "n5", "n1"),
                "n13": ("n11", "n4", "n12", "n1", "n9", "n3", "n0"),
            },
            ["n9", "n2", "n8", "n10", "n5", "n6", "n13", "n7", "n3", "n0", "n4"],
            0,
            3,
            [
                ["n9", "n2", "n5"],
                ["n8", "n10", "n4"],
                ["n6", "n3", "n0"],
                ["n13"],
                ["n7"],
            ],
        )

    def test_find_independent_subset_greedy_fuse(self):
        # Ensure that fusing sets during iteration yields correct results on the
        # subsequent iterations. In the example graph, after we merge n2 and n3,
        # n4 is no longer independent from n1.
        g, lookup = self.build_graph(
            {
                "n0": (),
                "n1": (),
                "n2": ("n0",),
                "n3": ("n1",),
                "n4": ("n2",),
                "n5": (),
            }
        )
        opts = {
            "min_fuse_set_size": 0,
            "max_fuse_set_size": 100,
        }
        subnodes = ["n2", "n3", "n4", "n0", "n1", "n5"]
        subnodes = [lookup[n] for n in subnodes]
        i = torch._inductor.fx_passes.group_batch_fusion.find_independent_subset_greedy(
            subnodes, opts
        )
        self.assertEqual(next(i), [lookup[n] for n in ["n2", "n3", "n5"]])

        # Fuse n2 and n3, which makes n4 dependent on n1.
        args = tuple(lookup[n] for n in ["n0", "n1"])
        fused = g.create_node("placeholder", "target", name="n2+n3", args=args)
        lookup["n2"].replace_all_uses_with(fused)
        g.erase_node(lookup["n2"])
        lookup["n3"].replace_all_uses_with(fused)
        g.erase_node(lookup["n3"])

        self.assertEqual(next(i), [lookup[n] for n in ["n4"]])
        self.assertEqual(next(i), [lookup[n] for n in ["n0", "n1"]])
        self.assertRaises(StopIteration, lambda: next(i))


if __name__ == "__main__":
    run_tests()