# xref: /aosp_15_r20/external/pytorch/test/dynamo/test_config.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["module: dynamo"]

3*da0073e9SAndroid Build Coastguard Workerimport torch
4*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case
5*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.testing
6*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.utils import disable_cache_limit


# NB: do NOT include this test class in test_dynamic_shapes.py — these tests
# set automatic_dynamic_shapes / assume_static_by_default explicitly and assert
# exact frame counts, so they must run under exactly the config they patch.


12*da0073e9SAndroid Build Coastguard Workerclass ConfigTests(torch._dynamo.test_case.TestCase):
13*da0073e9SAndroid Build Coastguard Worker    @disable_cache_limit()
14*da0073e9SAndroid Build Coastguard Worker    def test_no_automatic_dynamic(self):
15*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
16*da0073e9SAndroid Build Coastguard Worker            return a - b * 10
17*da0073e9SAndroid Build Coastguard Worker
18*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
19*da0073e9SAndroid Build Coastguard Worker        cnt_static = torch._dynamo.testing.CompileCounter()
20*da0073e9SAndroid Build Coastguard Worker        with torch._dynamo.config.patch(
21*da0073e9SAndroid Build Coastguard Worker            automatic_dynamic_shapes=False, assume_static_by_default=True
22*da0073e9SAndroid Build Coastguard Worker        ):
23*da0073e9SAndroid Build Coastguard Worker            opt_fn = torch._dynamo.optimize(cnt_static)(fn)
24*da0073e9SAndroid Build Coastguard Worker            for i in range(2, 12):
25*da0073e9SAndroid Build Coastguard Worker                opt_fn(torch.randn(i), torch.randn(i))
26*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt_static.frame_count, 10)
27*da0073e9SAndroid Build Coastguard Worker
28*da0073e9SAndroid Build Coastguard Worker    @disable_cache_limit()
29*da0073e9SAndroid Build Coastguard Worker    def test_automatic_dynamic(self):
30*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
31*da0073e9SAndroid Build Coastguard Worker            return a - b * 10
32*da0073e9SAndroid Build Coastguard Worker
33*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
34*da0073e9SAndroid Build Coastguard Worker        cnt_dynamic = torch._dynamo.testing.CompileCounter()
35*da0073e9SAndroid Build Coastguard Worker        with torch._dynamo.config.patch(
36*da0073e9SAndroid Build Coastguard Worker            automatic_dynamic_shapes=True, assume_static_by_default=True
37*da0073e9SAndroid Build Coastguard Worker        ):
38*da0073e9SAndroid Build Coastguard Worker            opt_fn = torch._dynamo.optimize(cnt_dynamic)(fn)
39*da0073e9SAndroid Build Coastguard Worker            # NB: must not do 0, 1 as they specialized
40*da0073e9SAndroid Build Coastguard Worker            for i in range(2, 12):
41*da0073e9SAndroid Build Coastguard Worker                opt_fn(torch.randn(i), torch.randn(i))
42*da0073e9SAndroid Build Coastguard Worker        # two graphs now rather than 10
43*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt_dynamic.frame_count, 2)
44*da0073e9SAndroid Build Coastguard Worker
45*da0073e9SAndroid Build Coastguard Worker    @disable_cache_limit()
46*da0073e9SAndroid Build Coastguard Worker    def test_no_assume_static_by_default(self):
47*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
48*da0073e9SAndroid Build Coastguard Worker            return a - b * 10
49*da0073e9SAndroid Build Coastguard Worker
50*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
51*da0073e9SAndroid Build Coastguard Worker        cnt_dynamic = torch._dynamo.testing.CompileCounter()
52*da0073e9SAndroid Build Coastguard Worker        with torch._dynamo.config.patch(
53*da0073e9SAndroid Build Coastguard Worker            automatic_dynamic_shapes=True, assume_static_by_default=False
54*da0073e9SAndroid Build Coastguard Worker        ):
55*da0073e9SAndroid Build Coastguard Worker            opt_fn = torch._dynamo.optimize(cnt_dynamic)(fn)
56*da0073e9SAndroid Build Coastguard Worker            # NB: must not do 0, 1 as they specialized
57*da0073e9SAndroid Build Coastguard Worker            for i in range(2, 12):
58*da0073e9SAndroid Build Coastguard Worker                opt_fn(torch.randn(i), torch.randn(i))
59*da0073e9SAndroid Build Coastguard Worker        # one graph now, as we didn't wait for recompile
60*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt_dynamic.frame_count, 1)
61*da0073e9SAndroid Build Coastguard Worker
62*da0073e9SAndroid Build Coastguard Worker    def test_config_compile_ignored(self):
63*da0073e9SAndroid Build Coastguard Worker        # Remove from this list if no longer relevant
64*da0073e9SAndroid Build Coastguard Worker        dynamo_guarded_config_ignorelist = {
65*da0073e9SAndroid Build Coastguard Worker            "log_file_name",
66*da0073e9SAndroid Build Coastguard Worker            "verbose",
67*da0073e9SAndroid Build Coastguard Worker            "verify_correctness",  # will not affect model, will raise RuntimeError
68*da0073e9SAndroid Build Coastguard Worker            # (no silent change to compilation behaviour)
69*da0073e9SAndroid Build Coastguard Worker            "cache_size_limit",
70*da0073e9SAndroid Build Coastguard Worker            "accumulated_cache_size_limit",
71*da0073e9SAndroid Build Coastguard Worker            "replay_record_enabled",
72*da0073e9SAndroid Build Coastguard Worker            "cprofile",  # only wraps _compile, not graph
73*da0073e9SAndroid Build Coastguard Worker            "repro_after",
74*da0073e9SAndroid Build Coastguard Worker            "repro_level",
75*da0073e9SAndroid Build Coastguard Worker            "repro_forward_only",
76*da0073e9SAndroid Build Coastguard Worker            "repro_tolerance",
77*da0073e9SAndroid Build Coastguard Worker            "same_two_models_use_fp64",
78*da0073e9SAndroid Build Coastguard Worker            "error_on_recompile",  # safe because: will throw error
79*da0073e9SAndroid Build Coastguard Worker            "report_guard_failures",
80*da0073e9SAndroid Build Coastguard Worker            "base_dir",  # used for minifying / logging
81*da0073e9SAndroid Build Coastguard Worker            "DEBUG_DIR_VAR_NAME",
82*da0073e9SAndroid Build Coastguard Worker            "debug_dir_root",
83*da0073e9SAndroid Build Coastguard Worker        }
84*da0073e9SAndroid Build Coastguard Worker        for k in dynamo_guarded_config_ignorelist:
85*da0073e9SAndroid Build Coastguard Worker            assert k in torch._dynamo.config._compile_ignored_keys, k
86*da0073e9SAndroid Build Coastguard Worker
87*da0073e9SAndroid Build Coastguard Worker    def test_config_hash(self):
88*da0073e9SAndroid Build Coastguard Worker        config = torch._dynamo.config
89*da0073e9SAndroid Build Coastguard Worker        starting_hash = config.get_hash()
90*da0073e9SAndroid Build Coastguard Worker
91*da0073e9SAndroid Build Coastguard Worker        with config.patch({"verbose": not config.verbose}):
92*da0073e9SAndroid Build Coastguard Worker            new_hash = config.get_hash()
93*da0073e9SAndroid Build Coastguard Worker            assert "verbose" in config._compile_ignored_keys
94*da0073e9SAndroid Build Coastguard Worker            assert new_hash == starting_hash
95*da0073e9SAndroid Build Coastguard Worker
96*da0073e9SAndroid Build Coastguard Worker        new_hash = config.get_hash()
97*da0073e9SAndroid Build Coastguard Worker        assert new_hash == starting_hash
98*da0073e9SAndroid Build Coastguard Worker
99*da0073e9SAndroid Build Coastguard Worker        with config.patch({"dead_code_elimination": not config.dead_code_elimination}):
100*da0073e9SAndroid Build Coastguard Worker            changed_hash = config.get_hash()
101*da0073e9SAndroid Build Coastguard Worker            assert "dead_code_elimination" not in config._compile_ignored_keys
102*da0073e9SAndroid Build Coastguard Worker            assert changed_hash != starting_hash
103*da0073e9SAndroid Build Coastguard Worker
104*da0073e9SAndroid Build Coastguard Worker            # Test nested patch
105*da0073e9SAndroid Build Coastguard Worker            with config.patch({"verbose": not config.verbose}):
106*da0073e9SAndroid Build Coastguard Worker                inner_changed_hash = config.get_hash()
107*da0073e9SAndroid Build Coastguard Worker                assert inner_changed_hash == changed_hash
108*da0073e9SAndroid Build Coastguard Worker                assert inner_changed_hash != starting_hash
109*da0073e9SAndroid Build Coastguard Worker
110*da0073e9SAndroid Build Coastguard Worker        newest_hash = config.get_hash()
111*da0073e9SAndroid Build Coastguard Worker        assert changed_hash != newest_hash
112*da0073e9SAndroid Build Coastguard Worker        assert newest_hash == starting_hash
113*da0073e9SAndroid Build Coastguard Worker
114*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Executed as a script: hand control to Dynamo's shared test runner.
    torch._dynamo.test_case.run_tests()