# Owner(s): ["module: inductor"]
import math
import unittest

import torch
from torch._inductor import config
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.inductor_utils import HAS_CPU


def dummy_fn(x):
    return torch.sigmoid(x + math.pi) / 10.0


class DummyModule(torch.nn.Module):
    def forward(self, x):
        return dummy_fn(x)


class TestInductorConfig(TestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
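        # snapshot the baseline config once; tearDown restores it after each test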
        cls._saved_config = config.save_config()

    def tearDown(self):
        super().tearDown()
        config.load_config(self._saved_config)

    def test_set(self):
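        # config entries can be read and set via plain attribute access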
        config.max_fusion_size = 13337
        self.assertEqual(config.max_fusion_size, 13337)
        self.assertEqual(config.shallow_copy_dict()["max_fusion_size"], 13337)
        config.max_fusion_size = 32
        self.assertEqual(config.max_fusion_size, 32)

        # a nested config
        prior = config.triton.cudagraphs
        config.triton.cudagraphs = not prior
        self.assertEqual(config.triton.cudagraphs, not prior)
        self.assertEqual(config.shallow_copy_dict()["triton.cudagraphs"], not prior)

    def test_save_load(self):
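        # save_config() snapshots the entire config; load_config() restores it wholesale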
        config.max_fusion_size = 123
        config.triton.cudagraphs = True
        saved1 = config.save_config()
        config.max_fusion_size = 321
        config.triton.cudagraphs = False
        saved2 = config.save_config()

        self.assertEqual(config.max_fusion_size, 321)
        self.assertEqual(config.triton.cudagraphs, False)
        config.load_config(saved1)
        self.assertEqual(config.max_fusion_size, 123)
        self.assertEqual(config.triton.cudagraphs, True)
        config.load_config(saved2)
        self.assertEqual(config.max_fusion_size, 321)
        self.assertEqual(config.triton.cudagraphs, False)

    def test_hasattr(self):
        self.assertTrue(hasattr(config, "max_fusion_size"))
        self.assertFalse(hasattr(config, "missing_name"))

    def test_invalid_names(self):
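        # unknown config keys should raise AttributeError on both read and write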
        self.assertRaises(AttributeError, lambda: config.does_not_exist)
        self.assertRaises(AttributeError, lambda: config.triton.does_not_exist)

        def store1():
            config.does_not_exist = True

        def store2():
            config.triton.does_not_exist = True

        self.assertRaises(AttributeError, store1)
        self.assertRaises(AttributeError, store2)

    def test_patch(self):
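        # config.patch is a context manager; nested patches restore prior values on exit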
        with config.patch(max_fusion_size=456):
            self.assertEqual(config.max_fusion_size, 456)
            with config.patch(max_fusion_size=789):
                self.assertEqual(config.max_fusion_size, 789)
            self.assertEqual(config.max_fusion_size, 456)

        with config.patch({"cpp.threads": 9000, "max_fusion_size": 9001}):
            self.assertEqual(config.cpp.threads, 9000)
            self.assertEqual(config.max_fusion_size, 9001)
            with config.patch("cpp.threads", 8999):
                self.assertEqual(config.cpp.threads, 8999)
            self.assertEqual(config.cpp.threads, 9000)

    @unittest.skipIf(not HAS_CPU, "requires C++ compiler")
    def test_compile_api(self):
        # these mostly check that config processing doesn't blow up with exceptions
        x = torch.randn(8)
        y = dummy_fn(x)
        checks = [
            {},
            {"mode": "default"},
            {"mode": "reduce-overhead"},
            {"mode": "max-autotune"},
            {
                "options": {
                    "max-fusion-size": 128,
                    "unroll_reductions_threshold": 32,
                    "triton.cudagraphs": False,
                }
            },
            {"dynamic": True},
            {"fullgraph": True, "backend": "inductor"},
            {"disable": True},
        ]

        for kwargs in checks:
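            # reset dynamo so each config is compiled from scratch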
            torch._dynamo.reset()
            opt_fn = torch.compile(dummy_fn, **kwargs)
            torch.testing.assert_close(
                opt_fn(x), y, msg=f"torch.compile(..., **{kwargs!r}) failed"
            )

    def test_get_compiler_config(self):
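        # get_compiler_config() returns the inductor config captured at compile
        # time, and is expected to be None for non-inductor backends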
        from torch._inductor import config as inductor_default_config

        default_cudagraphs = inductor_default_config._default["triton.cudagraphs"]

        # nn.Module: should update default config with a new value
        model = DummyModule()
        optimized_module = torch.compile(
            model, options={"triton.cudagraphs": not default_cudagraphs}
        )
        compiler_config = optimized_module.get_compiler_config()
        self.assertEqual(compiler_config["triton.cudagraphs"], not default_cudagraphs)

        # nn.Module: keep default config
        model = DummyModule()
        optimized_module = torch.compile(model)
        compiler_config = optimized_module.get_compiler_config()
        self.assertEqual(
            compiler_config["triton.cudagraphs"],
            default_cudagraphs,
        )

        # compile user func: should update default config with a new value
        optimized_module = torch.compile(
            dummy_fn, options={"triton.cudagraphs": not default_cudagraphs}
        )
        compiler_config = optimized_module.get_compiler_config()
        self.assertEqual(compiler_config["triton.cudagraphs"], not default_cudagraphs)

        # compile user func: keep default config
        optimized_module = torch.compile(dummy_fn)
        compiler_config = optimized_module.get_compiler_config()
        self.assertEqual(
            compiler_config["triton.cudagraphs"],
            default_cudagraphs,
        )

        # backend=eager: expect None
        optimized_module = torch.compile(dummy_fn, backend="eager")
        compiler_config = optimized_module.get_compiler_config()
        self.assertIsNone(compiler_config)

    def test_compile_api_passes_config(self):
        # ensure configs are actually passed down to inductor
        self.assertRaises(
            torch._dynamo.exc.BackendCompilerFailed,
            lambda: torch.compile(dummy_fn, options={"_raise_error_for_testing": True})(
                torch.randn(10)
            ),
        )

    def test_api_options(self):
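        # list_mode_options maps each mode string to the config overrides it implies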
        reduce_overhead_opts = torch._inductor.list_mode_options("reduce-overhead")
        self.assertEqual(reduce_overhead_opts["triton.cudagraphs"], True)
        self.assertEqual(reduce_overhead_opts.get("max_autotune", False), False)

        max_autotune_opts = torch._inductor.list_mode_options("max-autotune")
        self.assertEqual(max_autotune_opts["max_autotune"], True)
        self.assertEqual(max_autotune_opts["triton.cudagraphs"], True)

        max_autotune_opts = torch._inductor.list_mode_options(
            "max-autotune", dynamic=True
        )
        self.assertEqual(max_autotune_opts["max_autotune"], True)
        self.assertEqual(max_autotune_opts["triton.cudagraphs"], True)

        max_autotune_no_cudagraphs_opts = torch._inductor.list_mode_options(
            "max-autotune-no-cudagraphs"
        )
        self.assertEqual(max_autotune_no_cudagraphs_opts["max_autotune"], True)
        self.assertEqual(
            max_autotune_no_cudagraphs_opts.get("triton.cudagraphs", False), False
        )

    def test_invalid_backend(self):
        self.assertRaises(
            torch._dynamo.exc.InvalidBackend,
            lambda: torch.compile(dummy_fn, backend="does_not_exist")(torch.randn(10)),
        )

    def test_non_inductor_backend(self):
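        # backends that accept mode/options keyword arguments should receive the
        # values passed to torch.compile verbatim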
        def assert_options(expected_mode=None, expected_options=None):
            def backend(gm, _, *, mode=None, options=None):
                nonlocal call_count
                self.assertEqual(mode, expected_mode)
                self.assertEqual(options, expected_options)
                call_count += 1
                return gm

            return backend

        inp = torch.randn(8)

        def fn(x):
            return x + 1

        for mode, options in [
            (None, None),
            ("fast-mode", None),
            (None, {"foo": "bar"}),
        ]:
            call_count = 0
            torch.compile(
                fn, backend=assert_options(mode, options), mode=mode, options=options
            )(inp)
            torch._dynamo.reset()
            self.assertEqual(call_count, 1)


if __name__ == "__main__":
    run_tests()