# Owner(s): ["module: fx"]

import operator

import torch
import torch.fx
from torch.fx.experimental import const_fold
from torch.fx.passes.shape_prop import _extract_tensor_metadata, ShapeProp
from torch.testing._internal.common_utils import TestCase


class TestConstFold(TestCase):
    def _get_attr(self, node):
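        # Resolve the node's (possibly dotted) get_attr target, e.g. "mod.attr",
        # to the actual attribute object on the module that owns the graph.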
        mod = node.graph.owning_module
        target = str(node.target)
        target_atoms = target.split(".")
        curr_obj = mod
        for i, atom in enumerate(target_atoms):
            if not hasattr(curr_obj, atom):
                raise RuntimeError(
                    f"Node referenced nonexistent target '{'.'.join(target_atoms[:i])}'; "
                    f"original whole target: '{target}'"
                )
            curr_obj = getattr(curr_obj, atom)
        return curr_obj

    def _verify_const_fold_mod(self, mod_folded: const_fold.FoldedGraphModule):
        self.assertTrue(mod_folded.const_subgraph_module is not None)

        # Check that we don't have the const or non-const fold graphs in the gm, and
        # that we do have the const folded get_attr.
        found_folded_attrs = False
        for n in mod_folded.graph.nodes:
            if n.op == "get_attr" and n.target.startswith("_FX_CONST_FOLDED_ATTRS"):
                found_folded_attrs = True
            elif n.op == "call_module":
                self.assertTrue(n.target not in {"submod_0", "submod_1"})
        self.assertTrue(found_folded_attrs)

    def test_const_fold_basic_one_attr_no_name_collision(self):
        r"""
        Perform constant folding conversion, from original mod to split constant folding
        module with two split subgraphs, where there's a single attr to fold and
        a single output attr result to replace.

           attr1                 attr1
            | |                   | |
        x   add                   add
         \ /                       |
         sub   y                 output     (becomes attr add_1)
            \ /         ==> -------+------- (const/base subgraph split)
            mul  attr2       x   /          (input from previous subgraph
              \ /             \ /            is attr)
              add             sub   y
               |                 \ /
             output              mul  attr2
                                   \ /
                                   add
                                    |
                                  output
        """

        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr_1 = torch.nn.Parameter(torch.tensor([[-0.9]]))
                self.attr_2 = torch.nn.Parameter(torch.tensor([[17.1]]))

            def forward(self, x, y):
                a = self.attr_1 + self.attr_1
                x = x - a
                return x * y + self.attr_2

        mod = ConstFoldTestModule()
        mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod)
        self._verify_const_fold_mod(mod_folded)

        # Now run both folded and non-folded to check results equal.
        in_x, in_y = torch.tensor([[-0.45]]), torch.tensor([0.9])
        base_result = mod(in_x, in_y)
        fold_result = mod_folded(in_x, in_y)
        self.assertTrue(torch.equal(fold_result, base_result))

    def test_const_fold_basic_one_attr_name_collision(self):
        r"""
        Perform constant folding conversion, from original mod to split constant folding
        module with two split subgraphs, where there's a single attr to fold and
        a single output attr result to replace. Name the attrs such that they will
        collide by name with folded attrs.

           add_1                 add_1
            | |                   | |
        x   add                   add
         \ /                       |
         sub   y                 output     (becomes attr add_1)
            \ /         ==> -------+------- (const/base subgraph split)
            mul  add_2       x   /          (input from previous subgraph
              \ /             \ /            is attr)
              add             sub   y
               |                 \ /
             output              mul  add_2
                                   \ /
                                   add
                                    |
                                  output
        """

        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # Note: Named as such to result in name collision.
                self.add_1__CF = torch.nn.Parameter(torch.tensor([[1.0]]))
                self.add_2__CF = torch.nn.Parameter(torch.tensor([[17.1]]))

            def forward(self, x, y):
                a = self.add_1__CF + self.add_1__CF
                x = x - a
                return x * y + self.add_2__CF

        mod = ConstFoldTestModule()
        mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod)
        self._verify_const_fold_mod(mod_folded)

        # Now run both folded and non-folded to check results equal.
        in_x, in_y = torch.tensor([[5.0]]), torch.tensor([4.0])
        base_result = mod(in_x, in_y)
        fold_result = mod_folded(in_x, in_y)
        self.assertTrue(torch.equal(fold_result, base_result))

    def test_const_fold_basic_placeholder_reordered(self):
        """
        Test code path where placeholder comes after normal op node in FX
        """

        class ConstFoldTestModule(torch.nn.Module):
            def forward(self, x, y):
                return x * 2 + y

        mod = ConstFoldTestModule()
        mod = torch.fx.symbolic_trace(mod)
        yy = None
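        # Move the first call_function node in front of the placeholder for y, so
        # that a placeholder appears after a normal op node in the graph.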
        for n in mod.graph.nodes:
            if n.op == "placeholder" and n.target == "y":
                yy = n
            elif yy is not None and n.op == "call_function":
                yy.prepend(n)
                break

        mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod)

        self.assertTrue(mod_folded.const_subgraph_module is None)
        # Now run both folded and non-folded to check results equal.
        in_x = torch.tensor([[-0.45]])
        in_y = torch.tensor([[0.45]])
        base_result = mod(in_x, in_y)
        fold_result = mod_folded(in_x, in_y)
        self.assertTrue(torch.equal(fold_result, base_result))

    def test_const_fold_noop(self):
        r"""
        Check that a graph with no constant folding is handled correctly.

        x  attr1
         \ /
         sub
          |
        output
        """

        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr1 = torch.nn.Parameter(torch.tensor([[-0.9]]))

            def forward(self, x):
                return x - self.attr1

        mod = ConstFoldTestModule()
        mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod)

        # Check that the const subgraph module is None, since there was no folding to do.
        self.assertTrue(mod_folded.const_subgraph_module is None)

        # Now run both folded and non-folded to check results equal.
        in_x = torch.tensor([[-0.45]])
        base_result = mod(in_x)
        fold_result = mod_folded(in_x)
        self.assertTrue(torch.equal(fold_result, base_result))

    def test_const_fold_basic_two_attr_three_input(self):
        r"""
        Perform constant folding conversion, from original mod to split constant
        folding module with two split subgraphs, where there are two attrs to
        fold into a single output, and there are three placeholder inputs.

        attr1   attr2         attr1   attr2
            \   /                 \   /
         x   add                   add
          \ /                       |
          sub     y               output     (becomes attr add_1)
             \   /     ==>   -------+------- (const/base subgraph split)
              mul  z           x   /         (input from previous subgraph
                \ /             \ /           is attr)
                div              sub  y
                 |                 \ /
               output              mul  z
                                     \ /
                                     div
                                      |
                                    output
        """

        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr1 = torch.nn.Parameter(torch.tensor([[-0.9]]))
                self.attr2 = torch.nn.Parameter(torch.tensor([[1.32]]))

            def forward(self, x, y, z):
                a = self.attr1 + self.attr2
                sub = x - a
                mul = sub * y
                return mul / z

        mod = ConstFoldTestModule()
        mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod)
        self._verify_const_fold_mod(mod_folded)

        # Now run both folded and non-folded to check results equal.
        in_x, in_y, in_z = (
            torch.tensor([[-0.45]]),
            torch.tensor([0.9]),
            torch.tensor([1.1]),
        )
        base_result = mod(in_x, in_y, in_z)
        fold_result = mod_folded(in_x, in_y, in_z)
        self.assertTrue(torch.equal(fold_result, base_result))

    def test_const_fold_basic_two_attr(self):
        r"""
        Perform constant folding conversion, from original mod to split constant
        folding module with two split subgraphs, where there are two attrs to
        fold into a single output.

        attr1  attr2                attr1  attr2
            \ /                         \ /
        x   add                         add       (becomes attr add_1)
         \ /            ==>       -------+------- (const/base subgraph split)
         sub                         x   |        (input from previous subgraph is attr)
          |                           \ /
        output                        sub
                                       |
                                     output
        """

        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr1 = torch.nn.Parameter(torch.randn(2, 3))
                self.attr2 = torch.nn.Parameter(torch.randn(2, 3))

            def forward(self, x):
                y = self.attr1 + self.attr2
                return x + y

        mod = ConstFoldTestModule()
        mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod)
        self._verify_const_fold_mod(mod_folded)

        # Now run both folded and non-folded to check results equal.
        in_x = torch.randn(2, 3)
        fold_result = mod_folded(in_x)
        base_result = mod(in_x)
        self.assertTrue(torch.equal(fold_result, base_result))

    def test_const_fold_multi_const_folded_attrs(self):
        r"""
        Perform constant folding conversion, from original mod to split constant
        folding module with two split subgraphs, where there are two attrs to
        fold into two new attrs.

           attr1        attr2          attr1     attr2
           /    \         |           /     \      |
        permute  |       sum       permute   |    sum
            \   /        /                \ /      |
         x   add    y   /                 add      |
          \ /        \ /                   |       |
          sub        add                 output  output     (become attrs add_1 and mul_1)
             \       /        ==>   --------+-------+------ (const/base subgraph split)
              \     /                   x   |   y   |       (inputs from previous subgraph
                add                      \ /     \ /         are attrs)
                 |                       sub     add
               linear                       \   /
                 |                           add
               sigmoid                        |
                 |                          linear
               output                         |
                                            sigmoid
                                              |
                                            output
        """

        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr1 = torch.nn.Parameter(torch.randn(4, 4))
                self.attr2 = torch.nn.Parameter(torch.randn(4, 4))
                self.lin = torch.nn.Linear(4, 4)

            def forward(self, x, y):
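                # Two independent constant expressions below (attr1 plus its
                # transpose, and the sum over attr2) should each fold into an attr.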
                a = self.attr1 + self.attr1.permute(1, 0)
                x = x - a
                attr2_sum = torch.sum(self.attr2, dim=1)
                y = y + attr2_sum
                return torch.sigmoid(self.lin(x + y))

        mod = ConstFoldTestModule()
        mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod)
        self._verify_const_fold_mod(mod_folded)

        # Now run both folded and non-folded to check results equal.
        in_x, in_y = torch.randn(4, 4), torch.randn(4)
        fold_result = mod_folded(in_x, in_y)
        base_result = mod(in_x, in_y)
        self.assertTrue(torch.equal(fold_result, base_result))

    def test_const_fold_submod_hierarchy(self):
        r"""
        Perform constant folding conversion, from original mod to split constant folding
        module where one of the folded attrs comes from a submod deeper in the hierarchy
        of the base module.
        """

        class TracedThroughModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.internal_attr = torch.nn.Parameter(torch.randn(2, 3))

            def forward(self):
                return self.internal_attr

        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.my_mod = TracedThroughModule()
                self.attr = torch.nn.Parameter(torch.randn(2, 3))

            def forward(self, x):
                return self.attr + self.my_mod() + x

        mod = ConstFoldTestModule()
        mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod)
        self._verify_const_fold_mod(mod_folded)

        # Now run both folded and non-folded to check results equal.
        in_x = torch.randn(2, 3)
        fold_result = mod_folded(in_x)
        base_result = mod(in_x)
        self.assertTrue(torch.equal(fold_result, base_result))

    def test_retain_node_meta(self):
        r"""
        Perform constant folding conversion, and validate that node meta is retained.
        """

        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.nn.Parameter(torch.randn(2, 3))

            def forward(self, x):
                a = self.attr + self.attr
                return x - a

        mod = ConstFoldTestModule()
        gm = torch.fx.symbolic_trace(mod)

        # Add a count for each node to check after we const fold.
        for idx, node in enumerate(gm.graph.nodes):
            if node.op != "output":
                node.meta["meta_idx"] = idx

        # Pre-folding:
        # idx 0: placeholder
        # idx 1: get_attr (will no longer be used, hence removed)
        # idx 2: add (will be folded into a get_attr)
        # idx 3: sub

        gm_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(gm)
        self._verify_const_fold_mod(gm_folded)

        # Post-folding:
        # idx 0: placeholder
        # idx 2: get_attr (replaced original add; original get_attr was removed)
        # idx 3: sub

        # Check the expected indices are still here.
        for node in gm_folded.graph.nodes:
            if node.op == "placeholder":
                self.assertEqual(node.meta["meta_idx"], 0)
            elif node.op == "get_attr":
                self.assertEqual(node.meta["meta_idx"], 2)
            elif node.op == "call_function" and node.target == operator.sub:
                self.assertEqual(node.meta["meta_idx"], 3)
            else:
                self.assertEqual(node.op, "output")

        # Now run both folded and non-folded to check results equal.
        in_x = torch.randn(2, 3)
        fold_result = gm_folded(in_x)
        base_result = mod(in_x)
        self.assertTrue(torch.equal(fold_result, base_result))

    def test_const_fold_has_inlined_call_module_node(self):
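        """
        Check that a call_module node reached through a nested submodule
        (self.mod.relu) is handled correctly when splitting for const folding.
        """
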
        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.nn.Parameter(torch.randn(2, 3))
                self.mod = torch.nn.Identity()
                self.mod.relu = torch.nn.ReLU()

            def forward(self, x):
                a = self.attr + self.attr
                return self.mod.relu(x - a)

        mod = ConstFoldTestModule()
        gm_folded = const_fold.split_const_subgraphs(mod)

        # Now run both folded and non-folded to check results equal.
        in_x = torch.randn(2, 3)
        fold_result = gm_folded(in_x)
        base_result = mod(in_x)
        self.assertTrue(torch.equal(fold_result, base_result))

    def test_const_fold_module_attr(self):
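        """
        Check that an attr owned by a submodule (self.mod.attr) can be folded as part
        of a constant expression and still be used directly in the non-const subgraph.
        """
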
        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.const = torch.nn.Parameter(torch.randn(2, 3))
                self.mod = torch.nn.Identity()
                self.mod.attr = torch.nn.Parameter(torch.randn(2, 3))

            def forward(self, x):
                a = self.const + self.mod.attr
                x = x + a
                return x + self.mod.attr

        mod = ConstFoldTestModule()
        gm_folded = const_fold.split_const_subgraphs(mod)

        # Now run both folded and non-folded to check results equal.
        in_x = torch.randn(2, 3)
        fold_result = gm_folded(in_x)
        base_result = mod(in_x)
        self.assertTrue(torch.equal(fold_result, base_result))

    def test_const_fold_unused_placeholder(self):
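        """
        Check that unused placeholders (x and z here) are handled correctly when
        splitting for const folding.
        """
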
        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.const = torch.nn.Parameter(torch.randn(2, 3))

            def forward(self, x, y, z):
                a = self.const + self.const
                return y + a

        mod = ConstFoldTestModule()
        gm_folded = const_fold.split_const_subgraphs(mod)

        # Now run both folded and non-folded to check results equal.
        in_x = torch.randn(2, 3)
        fold_result = gm_folded(in_x, in_x, in_x)
        base_result = mod(in_x, in_x, in_x)
        self.assertTrue(torch.equal(fold_result, base_result))

    def test_dict_output(self):
        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.const = torch.nn.Parameter(torch.randn(2, 3))

            def forward(self, x):
                a = self.const + self.const
                return {"result": x + a}

        mod = ConstFoldTestModule()
        gm_folded = const_fold.split_const_subgraphs(mod)

        # Now run both folded and non-folded to check results equal.
        in_x = torch.randn(2, 3)
        fold_result = gm_folded(in_x)
        base_result = mod(in_x)
        self.assertTrue(torch.equal(fold_result["result"], base_result["result"]))

    def test_two_outputs(self):
        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.const = torch.nn.Parameter(torch.randn(2, 3))

            def forward(self, x):
                a = self.const + self.const
                return x, x + a

        mod = ConstFoldTestModule()
        gm_folded = const_fold.split_const_subgraphs(mod)

        # Now run both folded and non-folded to check results equal.
        in_x = torch.randn(2, 3)
        fold_result = gm_folded(in_x)
        base_result = mod(in_x)
        self.assertTrue(torch.equal(fold_result[0], base_result[0]))
        self.assertTrue(torch.equal(fold_result[1], base_result[1]))

    def test_three_outputs(self):
        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.const = torch.nn.Parameter(torch.randn(2, 3))

            def forward(self, x):
                a = self.const + self.const
                return x, x + a, x + a

        mod = ConstFoldTestModule()
        gm_folded = const_fold.split_const_subgraphs(mod)

        # Now run both folded and non-folded to check results equal.
        in_x = torch.randn(2, 3)
        fold_result = gm_folded(in_x)
        base_result = mod(in_x)
        self.assertTrue(torch.equal(fold_result[0], base_result[0]))
        self.assertTrue(torch.equal(fold_result[1], base_result[1]))
        self.assertTrue(torch.equal(fold_result[2], base_result[2]))

    def test_check_inline_non_const(self):
        r"""
        Perform constant folding conversion and check that the non-const module is inlined
        correctly.
        """

        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.nn.Parameter(torch.randn(2, 3))

            def forward(self, x):
                a = self.attr + self.attr
                return (x - a * x) / 2

        mod = ConstFoldTestModule()
        gm = torch.fx.symbolic_trace(mod)

        gm_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(gm)
        self._verify_const_fold_mod(gm_folded)

        # Check there are no call modules, because they've been inlined or extracted for
        # const folding.
        for node in gm_folded.graph.nodes:
            self.assertNotEqual(node.op, "call_module")

        # Now run both folded and non-folded to check results equal.
        in_x = torch.randn(2, 3)
        fold_result = gm_folded(in_x)
        base_result = mod(in_x)
        self.assertTrue(torch.equal(fold_result, base_result))

    def test_check_inline_non_const_mult_return(self):
        r"""
        Perform constant folding conversion and check that the non-const module is inlined
        correctly.
        """

        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.nn.Parameter(torch.randn(2, 3))

            def forward(self, x):
                a = self.attr + self.attr
                return x - a, x / 2

        mod = ConstFoldTestModule()
        gm = torch.fx.symbolic_trace(mod)

        gm_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(gm)
        self._verify_const_fold_mod(gm_folded)

        # Check there are no call modules, because they've been inlined or extracted for
        # const folding.
        for node in gm_folded.graph.nodes:
            self.assertNotEqual(node.op, "call_module")

        # Now run both folded and non-folded to check results equal.
        in_x = torch.randn(2, 3)
        fold_result = gm_folded(in_x)
        base_result = mod(in_x)
        self.assertTrue(torch.equal(fold_result[0], base_result[0]))
        self.assertTrue(torch.equal(fold_result[1], base_result[1]))

    def test_check_skip_folding_quant_dequant_pattern(self):
        r"""
        Set up a skip_folding_quant_dequant function to skip the quant/dequant pattern.
        This example shows how to use skip_folding_node_fn.
        """

        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(torch.randn(4, 4))
                self.bias = torch.nn.Parameter(torch.randn(4))
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                quant_weight = torch.quantize_per_tensor(
                    self.weight, 0.5, 3, torch.quint8
                )
                dequant_weight = torch.dequantize(quant_weight)
                output = torch.nn.functional.linear(x, dequant_weight, self.bias)
                return self.relu(output)

        mod = ConstFoldTestModule()
        in_x = torch.randn(2, 4)
        gm = torch.fx.symbolic_trace(mod)

        def skip_folding_quant_dequant(node: torch.fx.Node):
            if node.target != torch.quantize_per_tensor:
                return False
            # If quantize_per_tensor -> dequantize, then skip folding.
            for user in node.users:
                if user.target == torch.dequantize:
                    return True
            return False

        gm_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(
            gm, skip_folding_node_fn=skip_folding_quant_dequant
        )

        # Check that the const subgraph module is None, since there was no folding to do.
        self.assertTrue(gm_folded.const_subgraph_module is None)

        # Now run both folded and non-folded to check results equal.
        fold_result = gm_folded(in_x)
        base_result = mod(in_x)
        self.assertTrue(torch.equal(fold_result, base_result))

    def test_fold_module(self):
        r"""
        Perform constant folding with a call_module node.
        """

        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin_input = torch.nn.Parameter(torch.randn(4, 4))
                self.lin = torch.nn.Linear(4, 4)

            def forward(self, x):
                return self.lin(self.lin_input) + x

        mod = ConstFoldTestModule()
        mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod)
        self._verify_const_fold_mod(mod_folded)

        # Now run both folded and non-folded to check results equal.
        inp = torch.randn(4, 4)
        self.assertTrue(torch.equal(mod_folded(inp), mod(inp)))

    def test_const_fold_tensor_meta(self):
        self._test_const_fold_tensor_meta(True)
        self._test_const_fold_tensor_meta(False)

    def _test_const_fold_tensor_meta(self, requires_grad):
        """
        Verify tensor_meta is handled correctly.
        """

        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr_1 = torch.nn.Parameter(torch.tensor([[-0.9]]), requires_grad)
                self.attr_2 = torch.nn.Parameter(torch.tensor([[17.1]]), requires_grad)

            def forward(self, x, y):
                a = self.attr_1 + self.attr_1
                x = x - a
                return x * y + self.attr_2

        mod = ConstFoldTestModule()
        gm = torch.fx.symbolic_trace(mod)
        in_x, in_y = torch.tensor([[-0.45]]), torch.tensor([0.9])
        ShapeProp(gm).propagate(in_x, in_y)
        mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(
            gm, device_for_folded_attrs="cpu"
        )
        self._verify_const_fold_mod(mod_folded)

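        # Run folding now so the folded attrs are materialized before comparing
        # their tensor metadata against the get_attr nodes below.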
        mod_folded.run_folding()

        for n in mod_folded.graph.nodes:
            if n.op == "get_attr":
                attr = self._get_attr(n)
                self.assertEqual(_extract_tensor_metadata(attr), n.meta["tensor_meta"])

        # Now run both folded and non-folded to check results equal.
        base_result = mod(in_x, in_y)
        fold_result = mod_folded(in_x, in_y)
        self.assertTrue(torch.equal(fold_result, base_result))
709