1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport torch 4*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case 5*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.testing 6*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.utils import disable_cache_limit 7*da0073e9SAndroid Build Coastguard Worker 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Worker# NB: do NOT include this test class in test_dynamic_shapes.py 10*da0073e9SAndroid Build Coastguard Worker 11*da0073e9SAndroid Build Coastguard Worker 12*da0073e9SAndroid Build Coastguard Workerclass ConfigTests(torch._dynamo.test_case.TestCase): 13*da0073e9SAndroid Build Coastguard Worker @disable_cache_limit() 14*da0073e9SAndroid Build Coastguard Worker def test_no_automatic_dynamic(self): 15*da0073e9SAndroid Build Coastguard Worker def fn(a, b): 16*da0073e9SAndroid Build Coastguard Worker return a - b * 10 17*da0073e9SAndroid Build Coastguard Worker 18*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 19*da0073e9SAndroid Build Coastguard Worker cnt_static = torch._dynamo.testing.CompileCounter() 20*da0073e9SAndroid Build Coastguard Worker with torch._dynamo.config.patch( 21*da0073e9SAndroid Build Coastguard Worker automatic_dynamic_shapes=False, assume_static_by_default=True 22*da0073e9SAndroid Build Coastguard Worker ): 23*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnt_static)(fn) 24*da0073e9SAndroid Build Coastguard Worker for i in range(2, 12): 25*da0073e9SAndroid Build Coastguard Worker opt_fn(torch.randn(i), torch.randn(i)) 26*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt_static.frame_count, 10) 27*da0073e9SAndroid Build Coastguard Worker 28*da0073e9SAndroid Build Coastguard Worker @disable_cache_limit() 29*da0073e9SAndroid Build Coastguard Worker def test_automatic_dynamic(self): 30*da0073e9SAndroid Build Coastguard Worker def fn(a, b): 31*da0073e9SAndroid Build Coastguard Worker return a - b * 10 32*da0073e9SAndroid Build Coastguard Worker 33*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 34*da0073e9SAndroid Build Coastguard Worker cnt_dynamic = torch._dynamo.testing.CompileCounter() 35*da0073e9SAndroid Build Coastguard Worker with torch._dynamo.config.patch( 36*da0073e9SAndroid Build Coastguard Worker automatic_dynamic_shapes=True, assume_static_by_default=True 37*da0073e9SAndroid Build Coastguard Worker ): 38*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnt_dynamic)(fn) 39*da0073e9SAndroid Build Coastguard Worker # NB: must not do 0, 1 as they specialized 40*da0073e9SAndroid Build Coastguard Worker for i in range(2, 12): 41*da0073e9SAndroid Build Coastguard Worker opt_fn(torch.randn(i), torch.randn(i)) 42*da0073e9SAndroid Build Coastguard Worker # two graphs now rather than 10 43*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt_dynamic.frame_count, 2) 44*da0073e9SAndroid Build Coastguard Worker 45*da0073e9SAndroid Build Coastguard Worker @disable_cache_limit() 46*da0073e9SAndroid Build Coastguard Worker def test_no_assume_static_by_default(self): 47*da0073e9SAndroid Build Coastguard Worker def fn(a, b): 48*da0073e9SAndroid Build Coastguard Worker return a - b * 10 49*da0073e9SAndroid Build Coastguard Worker 50*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 51*da0073e9SAndroid Build Coastguard Worker cnt_dynamic = torch._dynamo.testing.CompileCounter() 52*da0073e9SAndroid Build Coastguard Worker with torch._dynamo.config.patch( 53*da0073e9SAndroid Build Coastguard Worker automatic_dynamic_shapes=True, assume_static_by_default=False 54*da0073e9SAndroid Build Coastguard Worker ): 55*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnt_dynamic)(fn) 56*da0073e9SAndroid Build Coastguard Worker # NB: must not do 0, 1 as they specialized 57*da0073e9SAndroid Build Coastguard Worker for i in range(2, 12): 58*da0073e9SAndroid Build Coastguard Worker opt_fn(torch.randn(i), torch.randn(i)) 59*da0073e9SAndroid Build Coastguard Worker # one graph now, as we didn't wait for recompile 60*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt_dynamic.frame_count, 1) 61*da0073e9SAndroid Build Coastguard Worker 62*da0073e9SAndroid Build Coastguard Worker def test_config_compile_ignored(self): 63*da0073e9SAndroid Build Coastguard Worker # Remove from this list if no longer relevant 64*da0073e9SAndroid Build Coastguard Worker dynamo_guarded_config_ignorelist = { 65*da0073e9SAndroid Build Coastguard Worker "log_file_name", 66*da0073e9SAndroid Build Coastguard Worker "verbose", 67*da0073e9SAndroid Build Coastguard Worker "verify_correctness", # will not affect model, will raise RuntimeError 68*da0073e9SAndroid Build Coastguard Worker # (no silent change to compilation behaviour) 69*da0073e9SAndroid Build Coastguard Worker "cache_size_limit", 70*da0073e9SAndroid Build Coastguard Worker "accumulated_cache_size_limit", 71*da0073e9SAndroid Build Coastguard Worker "replay_record_enabled", 72*da0073e9SAndroid Build Coastguard Worker "cprofile", # only wraps _compile, not graph 73*da0073e9SAndroid Build Coastguard Worker "repro_after", 74*da0073e9SAndroid Build Coastguard Worker "repro_level", 75*da0073e9SAndroid Build Coastguard Worker "repro_forward_only", 76*da0073e9SAndroid Build Coastguard Worker "repro_tolerance", 77*da0073e9SAndroid Build Coastguard Worker "same_two_models_use_fp64", 78*da0073e9SAndroid Build Coastguard Worker "error_on_recompile", # safe because: will throw error 79*da0073e9SAndroid Build Coastguard Worker "report_guard_failures", 80*da0073e9SAndroid Build Coastguard Worker "base_dir", # used for minifying / logging 81*da0073e9SAndroid Build Coastguard Worker "DEBUG_DIR_VAR_NAME", 82*da0073e9SAndroid Build Coastguard Worker "debug_dir_root", 83*da0073e9SAndroid Build Coastguard Worker } 84*da0073e9SAndroid Build Coastguard Worker for k in dynamo_guarded_config_ignorelist: 85*da0073e9SAndroid Build Coastguard Worker assert k in torch._dynamo.config._compile_ignored_keys, k 86*da0073e9SAndroid Build Coastguard Worker 87*da0073e9SAndroid Build Coastguard Worker def test_config_hash(self): 88*da0073e9SAndroid Build Coastguard Worker config = torch._dynamo.config 89*da0073e9SAndroid Build Coastguard Worker starting_hash = config.get_hash() 90*da0073e9SAndroid Build Coastguard Worker 91*da0073e9SAndroid Build Coastguard Worker with config.patch({"verbose": not config.verbose}): 92*da0073e9SAndroid Build Coastguard Worker new_hash = config.get_hash() 93*da0073e9SAndroid Build Coastguard Worker assert "verbose" in config._compile_ignored_keys 94*da0073e9SAndroid Build Coastguard Worker assert new_hash == starting_hash 95*da0073e9SAndroid Build Coastguard Worker 96*da0073e9SAndroid Build Coastguard Worker new_hash = config.get_hash() 97*da0073e9SAndroid Build Coastguard Worker assert new_hash == starting_hash 98*da0073e9SAndroid Build Coastguard Worker 99*da0073e9SAndroid Build Coastguard Worker with config.patch({"dead_code_elimination": not config.dead_code_elimination}): 100*da0073e9SAndroid Build Coastguard Worker changed_hash = config.get_hash() 101*da0073e9SAndroid Build Coastguard Worker assert "dead_code_elimination" not in config._compile_ignored_keys 102*da0073e9SAndroid Build Coastguard Worker assert changed_hash != starting_hash 103*da0073e9SAndroid Build Coastguard Worker 104*da0073e9SAndroid Build Coastguard Worker # Test nested patch 105*da0073e9SAndroid Build Coastguard Worker with config.patch({"verbose": not config.verbose}): 106*da0073e9SAndroid Build Coastguard Worker inner_changed_hash = config.get_hash() 107*da0073e9SAndroid Build Coastguard Worker assert inner_changed_hash == changed_hash 108*da0073e9SAndroid Build Coastguard Worker assert inner_changed_hash != starting_hash 109*da0073e9SAndroid Build Coastguard Worker 110*da0073e9SAndroid Build Coastguard Worker newest_hash = config.get_hash() 111*da0073e9SAndroid Build Coastguard Worker assert changed_hash != newest_hash 112*da0073e9SAndroid Build Coastguard Worker assert newest_hash == starting_hash 113*da0073e9SAndroid Build Coastguard Worker 114*da0073e9SAndroid Build Coastguard Worker 115*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 116*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.test_case import run_tests 117*da0073e9SAndroid Build Coastguard Worker 118*da0073e9SAndroid Build Coastguard Worker run_tests() 119