# Owner(s): ["module: dynamo"]
import contextlib
import functools
import logging
import os
import unittest.mock

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch.distributed as dist
from torch._dynamo.testing import empty_line_normalizer, skipIfNotPy311
from torch._dynamo.trace_rules import _as_posix_path
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_utils import (
    find_free_port,
    munge_exc,
    skipIfTorchDynamo,
)
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.testing._internal.logging_utils import (
    LoggingTestCase,
    make_logging_test,
    make_settings_test,
)


# Class/method decorator that skips a test unless HAS_CUDA (from
# torch.testing._internal.inductor_utils) is true. Used directly, without
# calling it: `@requires_cuda`.
requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
# Factory for a skip decorator: `requires_distributed()` returns
# `unittest.skipIf(not dist.is_available(), "requires distributed")`.
# Unlike `requires_cuda`, it must be *called* when used: `@requires_distributed()`.
requires_distributed = functools.partial(
    unittest.skipIf, not dist.is_available(), "requires distributed"
)


# Baseline workload compiled by the record-counting tests below: a mul and an
# add against 1000x1000 ones tensors. The record-count ranges asserted by the
# tests depend on what gets traced here, so changes to this body (even
# behaviour-preserving ones) may shift those counts.
def example_fn(a):
    output = a.mul(torch.ones(1000, 1000))
    output = output.add(torch.ones(1000, 1000))
    return output


# Deliberately broken: adds a (10, 10) tensor to the (1000, 1000) product, a
# shape mismatch that test_dynamo_error expects to surface as a "WON'T CONVERT"
# record. That test's expected output quotes the `output.add(torch.ones(10, 10))`
# source line verbatim, so keep that line unchanged.
def dynamo_error_fn(a):
    output = a.mul(torch.ones(1000, 1000))
    output = output.add(torch.ones(10, 10))
    return output


# Calls torch.round; test_inductor_error patches every inductor lowering whose
# name contains "round" to raise, so compiling this function fails in inductor.
def inductor_error_fn(a):
    output = torch.round(a)
    return output


# CUDA-only workload (allocates on device="cuda"); used by the
# @requires_cuda tests for the schedule / fusion / cudagraphs log artifacts.
def inductor_schedule_fn(a):
    output = a.add(torch.ones(1000, 1000, device="cuda"))
    return output


# Shared positional args for the compiled functions above. requires_grad=True
# means the traced graph carries autograd history (test_dynamo_error's expected
# message shows grad_fn=<MulBackward0>).
ARGS = (torch.ones(1000, 1000, requires_grad=True),)


# Build a test method that compiles example_fn with the "inductor" backend,
# runs it once on ARGS, and asserts exactly `num_records` log records were
# captured. **kwargs are forwarded to make_logging_test as the logging settings
# for the test (e.g. bytecode=True, dynamo=logging.INFO); `records` is
# presumably injected by make_logging_test -- see logging_utils.
def multi_record_test(num_records, **kwargs):
    @make_logging_test(**kwargs)
    def fn(self, records):
        fn_opt = torch._dynamo.optimize("inductor")(example_fn)
        fn_opt(*ARGS)
        self.assertEqual(len(records), num_records)

    return fn


# Same as multi_record_test, but accepts any record count in the inclusive
# range [num_records_lower, num_records_higher]. Used where the exact count is
# not stable enough to pin.
def within_range_record_test(num_records_lower, num_records_higher, **kwargs):
    @make_logging_test(**kwargs)
    def fn(self, records):
        fn_opt = torch._dynamo.optimize("inductor")(example_fn)
        fn_opt(*ARGS)
        self.assertGreaterEqual(len(records), num_records_lower)
        self.assertLessEqual(len(records), num_records_higher)

    return fn


# Shorthand for multi_record_test(1, ...): expect exactly one record.
def single_record_test(**kwargs):
    return multi_record_test(1, **kwargs)


class LoggingTests(LoggingTestCase):
    # Artifact logs expected to emit a fixed number of records for example_fn.
    test_bytecode = multi_record_test(2, bytecode=True)
    test_output_code = multi_record_test(2, output_code=True)
    test_aot_graphs = multi_record_test(3, aot_graphs=True)

Coastguard Worker @requires_cuda 90*da0073e9SAndroid Build Coastguard Worker @make_logging_test(schedule=True) 91*da0073e9SAndroid Build Coastguard Worker def test_schedule(self, records): 92*da0073e9SAndroid Build Coastguard Worker fn_opt = torch._dynamo.optimize("inductor")(inductor_schedule_fn) 93*da0073e9SAndroid Build Coastguard Worker fn_opt(torch.ones(1000, 1000, device="cuda")) 94*da0073e9SAndroid Build Coastguard Worker self.assertGreater(len(records), 0) 95*da0073e9SAndroid Build Coastguard Worker self.assertLess(len(records), 5) 96*da0073e9SAndroid Build Coastguard Worker 97*da0073e9SAndroid Build Coastguard Worker @requires_cuda 98*da0073e9SAndroid Build Coastguard Worker @make_logging_test(fusion=True) 99*da0073e9SAndroid Build Coastguard Worker def test_fusion(self, records): 100*da0073e9SAndroid Build Coastguard Worker fn_opt = torch._dynamo.optimize("inductor")(inductor_schedule_fn) 101*da0073e9SAndroid Build Coastguard Worker fn_opt(torch.ones(1000, 1000, device="cuda")) 102*da0073e9SAndroid Build Coastguard Worker self.assertGreater(len(records), 0) 103*da0073e9SAndroid Build Coastguard Worker self.assertLess(len(records), 8) 104*da0073e9SAndroid Build Coastguard Worker 105*da0073e9SAndroid Build Coastguard Worker @requires_cuda 106*da0073e9SAndroid Build Coastguard Worker @make_logging_test(cudagraphs=True) 107*da0073e9SAndroid Build Coastguard Worker def test_cudagraphs(self, records): 108*da0073e9SAndroid Build Coastguard Worker fn_opt = torch.compile(mode="reduce-overhead")(inductor_schedule_fn) 109*da0073e9SAndroid Build Coastguard Worker fn_opt(torch.ones(1000, 1000, device="cuda")) 110*da0073e9SAndroid Build Coastguard Worker self.assertGreater(len(records), 0) 111*da0073e9SAndroid Build Coastguard Worker self.assertLess(len(records), 8) 112*da0073e9SAndroid Build Coastguard Worker 113*da0073e9SAndroid Build Coastguard Worker @make_logging_test(recompiles=True) 114*da0073e9SAndroid Build Coastguard Worker def test_recompiles(self, records): 
115*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 116*da0073e9SAndroid Build Coastguard Worker return torch.add(x, y) 117*da0073e9SAndroid Build Coastguard Worker 118*da0073e9SAndroid Build Coastguard Worker fn_opt = torch._dynamo.optimize("inductor")(fn) 119*da0073e9SAndroid Build Coastguard Worker fn_opt(torch.ones(1000, 1000), torch.ones(1000, 1000)) 120*da0073e9SAndroid Build Coastguard Worker fn_opt(torch.ones(1000, 1000), 1) 121*da0073e9SAndroid Build Coastguard Worker self.assertGreater(len(records), 0) 122*da0073e9SAndroid Build Coastguard Worker 123*da0073e9SAndroid Build Coastguard Worker test_dynamo_debug = within_range_record_test(30, 90, dynamo=logging.DEBUG) 124*da0073e9SAndroid Build Coastguard Worker test_dynamo_info = within_range_record_test(2, 10, dynamo=logging.INFO) 125*da0073e9SAndroid Build Coastguard Worker 126*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("too slow") 127*da0073e9SAndroid Build Coastguard Worker @make_logging_test(dynamo=logging.DEBUG) 128*da0073e9SAndroid Build Coastguard Worker def test_dynamo_debug_default_off_artifacts(self, records): 129*da0073e9SAndroid Build Coastguard Worker fn_opt = torch._dynamo.optimize("inductor")(example_fn) 130*da0073e9SAndroid Build Coastguard Worker fn_opt(torch.ones(1000, 1000)) 131*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len([r for r in records if ".__bytecode" in r.name]), 0) 132*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len([r for r in records if ".__output_code" in r.name]), 0) 133*da0073e9SAndroid Build Coastguard Worker 134*da0073e9SAndroid Build Coastguard Worker @make_logging_test() 135*da0073e9SAndroid Build Coastguard Worker def test_dynamo_error(self, records): 136*da0073e9SAndroid Build Coastguard Worker try: 137*da0073e9SAndroid Build Coastguard Worker fn_opt = torch._dynamo.optimize("inductor")(dynamo_error_fn) 138*da0073e9SAndroid Build Coastguard Worker fn_opt(*ARGS) 139*da0073e9SAndroid Build Coastguard Worker except 
Exception: 140*da0073e9SAndroid Build Coastguard Worker pass 141*da0073e9SAndroid Build Coastguard Worker record = self.getRecord(records, "WON'T CONVERT") 142*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 143*da0073e9SAndroid Build Coastguard Worker munge_exc(record.getMessage()), 144*da0073e9SAndroid Build Coastguard Worker """\ 145*da0073e9SAndroid Build Coastguard WorkerWON'T CONVERT dynamo_error_fn test_logging.py line N 146*da0073e9SAndroid Build Coastguard Workerdue to: 147*da0073e9SAndroid Build Coastguard WorkerTraceback (most recent call last): 148*da0073e9SAndroid Build Coastguard Workertorch._dynamo.exc.TorchRuntimeError: Failed running call_method add(*(FakeTensor(..., size=(1000, 1000), grad_fn=<MulBackward0>), FakeTensor(..., size=(10, 10))), **{}): 149*da0073e9SAndroid Build Coastguard WorkerAttempting to broadcast a dimension of length 10 at -1! Mismatching argument at index 1 had torch.Size([10, 10]); but expected shape should be broadcastable to [1000, 1000] 150*da0073e9SAndroid Build Coastguard Worker 151*da0073e9SAndroid Build Coastguard Workerfrom user code: 152*da0073e9SAndroid Build Coastguard Worker File "test_logging.py", line N, in dynamo_error_fn 153*da0073e9SAndroid Build Coastguard Worker output = output.add(torch.ones(10, 10))""", # noqa: B950 154*da0073e9SAndroid Build Coastguard Worker ) 155*da0073e9SAndroid Build Coastguard Worker 156*da0073e9SAndroid Build Coastguard Worker test_aot = within_range_record_test(2, 6, aot=logging.INFO) 157*da0073e9SAndroid Build Coastguard Worker test_inductor_debug = within_range_record_test(3, 17, inductor=logging.DEBUG) 158*da0073e9SAndroid Build Coastguard Worker test_inductor_info = within_range_record_test(2, 4, inductor=logging.INFO) 159*da0073e9SAndroid Build Coastguard Worker 160*da0073e9SAndroid Build Coastguard Worker @make_logging_test() 161*da0073e9SAndroid Build Coastguard Worker def test_inductor_error(self, records): 162*da0073e9SAndroid Build Coastguard Worker 
exitstack = contextlib.ExitStack() 163*da0073e9SAndroid Build Coastguard Worker import torch._inductor.lowering 164*da0073e9SAndroid Build Coastguard Worker 165*da0073e9SAndroid Build Coastguard Worker def throw(x): 166*da0073e9SAndroid Build Coastguard Worker raise AssertionError 167*da0073e9SAndroid Build Coastguard Worker 168*da0073e9SAndroid Build Coastguard Worker # inject an error in the lowerings 169*da0073e9SAndroid Build Coastguard Worker dict_entries = {} 170*da0073e9SAndroid Build Coastguard Worker for x in list(torch._inductor.lowering.lowerings.keys()): 171*da0073e9SAndroid Build Coastguard Worker if "round" in x.__name__: 172*da0073e9SAndroid Build Coastguard Worker dict_entries[x] = throw 173*da0073e9SAndroid Build Coastguard Worker 174*da0073e9SAndroid Build Coastguard Worker exitstack.enter_context( 175*da0073e9SAndroid Build Coastguard Worker unittest.mock.patch.dict(torch._inductor.lowering.lowerings, dict_entries) 176*da0073e9SAndroid Build Coastguard Worker ) 177*da0073e9SAndroid Build Coastguard Worker 178*da0073e9SAndroid Build Coastguard Worker try: 179*da0073e9SAndroid Build Coastguard Worker fn_opt = torch._dynamo.optimize("inductor")(inductor_error_fn) 180*da0073e9SAndroid Build Coastguard Worker fn_opt(*ARGS) 181*da0073e9SAndroid Build Coastguard Worker except Exception: 182*da0073e9SAndroid Build Coastguard Worker pass 183*da0073e9SAndroid Build Coastguard Worker record = self.getRecord(records, "WON'T CONVERT") 184*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 185*da0073e9SAndroid Build Coastguard Worker munge_exc(record.getMessage()), 186*da0073e9SAndroid Build Coastguard Worker """\ 187*da0073e9SAndroid Build Coastguard WorkerWON'T CONVERT inductor_error_fn test_logging.py line N 188*da0073e9SAndroid Build Coastguard Workerdue to: 189*da0073e9SAndroid Build Coastguard WorkerTraceback (most recent call last): 190*da0073e9SAndroid Build Coastguard Worker File "test_logging.py", line N, in throw 
191*da0073e9SAndroid Build Coastguard Worker raise AssertionError 192*da0073e9SAndroid Build Coastguard Workertorch._inductor.exc.LoweringException: AssertionError: 193*da0073e9SAndroid Build Coastguard Worker target: aten.round.default 194*da0073e9SAndroid Build Coastguard Worker args[0]: TensorBox(StorageBox( 195*da0073e9SAndroid Build Coastguard Worker InputBuffer(name='primals_1', layout=FixedLayout('cpu', torch.float32, size=[1000, 1000], stride=[1000, 1])) 196*da0073e9SAndroid Build Coastguard Worker )) 197*da0073e9SAndroid Build Coastguard Worker 198*da0073e9SAndroid Build Coastguard WorkerThe above exception was the direct cause of the following exception: 199*da0073e9SAndroid Build Coastguard Worker 200*da0073e9SAndroid Build Coastguard WorkerTraceback (most recent call last): 201*da0073e9SAndroid Build Coastguard Workertorch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised: 202*da0073e9SAndroid Build Coastguard WorkerLoweringException: AssertionError: 203*da0073e9SAndroid Build Coastguard Worker target: aten.round.default 204*da0073e9SAndroid Build Coastguard Worker args[0]: TensorBox(StorageBox( 205*da0073e9SAndroid Build Coastguard Worker InputBuffer(name='primals_1', layout=FixedLayout('cpu', torch.float32, size=[1000, 1000], stride=[1000, 1])) 206*da0073e9SAndroid Build Coastguard Worker ))""", 207*da0073e9SAndroid Build Coastguard Worker ) 208*da0073e9SAndroid Build Coastguard Worker 209*da0073e9SAndroid Build Coastguard Worker exitstack.close() 210*da0073e9SAndroid Build Coastguard Worker 211*da0073e9SAndroid Build Coastguard Worker @requires_distributed() 212*da0073e9SAndroid Build Coastguard Worker @requires_cuda 213*da0073e9SAndroid Build Coastguard Worker @make_logging_test(ddp_graphs=True) 214*da0073e9SAndroid Build Coastguard Worker def test_ddp_graphs(self, records): 215*da0073e9SAndroid Build Coastguard Worker class ToyModel(torch.nn.Module): 216*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 
217*da0073e9SAndroid Build Coastguard Worker super().__init__() 218*da0073e9SAndroid Build Coastguard Worker self.layers = torch.nn.Sequential( 219*da0073e9SAndroid Build Coastguard Worker torch.nn.Linear(1024, 1024), 220*da0073e9SAndroid Build Coastguard Worker torch.nn.Linear(1024, 1024), 221*da0073e9SAndroid Build Coastguard Worker ) 222*da0073e9SAndroid Build Coastguard Worker 223*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 224*da0073e9SAndroid Build Coastguard Worker return self.layers(x) 225*da0073e9SAndroid Build Coastguard Worker 226*da0073e9SAndroid Build Coastguard Worker os.environ["MASTER_ADDR"] = "localhost" 227*da0073e9SAndroid Build Coastguard Worker os.environ["MASTER_PORT"] = str(find_free_port()) 228*da0073e9SAndroid Build Coastguard Worker dist.init_process_group("gloo", rank=0, world_size=1) 229*da0073e9SAndroid Build Coastguard Worker 230*da0073e9SAndroid Build Coastguard Worker ddp_model = torch._dynamo.optimize("inductor")( 231*da0073e9SAndroid Build Coastguard Worker DDP(ToyModel().to("cuda:0"), device_ids=[0], bucket_cap_mb=4) 232*da0073e9SAndroid Build Coastguard Worker ) 233*da0073e9SAndroid Build Coastguard Worker 234*da0073e9SAndroid Build Coastguard Worker ddp_model(torch.randn(1024, 1024, device="cuda:0")) 235*da0073e9SAndroid Build Coastguard Worker 236*da0073e9SAndroid Build Coastguard Worker dist.destroy_process_group() 237*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len([r for r in records if "__ddp_graphs" in r.name]), 4) 238*da0073e9SAndroid Build Coastguard Worker 239*da0073e9SAndroid Build Coastguard Worker # check that logging to a child log of a registered logger 240*da0073e9SAndroid Build Coastguard Worker # does not register it and result in duplicated records 241*da0073e9SAndroid Build Coastguard Worker @make_settings_test("torch._dynamo.output_graph") 242*da0073e9SAndroid Build Coastguard Worker def test_open_registration_with_registered_parent(self, records): 243*da0073e9SAndroid Build 
Coastguard Worker logger = logging.getLogger("torch._dynamo.output_graph") 244*da0073e9SAndroid Build Coastguard Worker logger.info("hi") 245*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(records), 1) 246*da0073e9SAndroid Build Coastguard Worker 247*da0073e9SAndroid Build Coastguard Worker # check logging to a random log that is not a child log of a registered 248*da0073e9SAndroid Build Coastguard Worker # logger registers it and sets handlers properly 249*da0073e9SAndroid Build Coastguard Worker @make_settings_test("torch.utils") 250*da0073e9SAndroid Build Coastguard Worker def test_open_registration(self, records): 251*da0073e9SAndroid Build Coastguard Worker logger = logging.getLogger("torch.utils") 252*da0073e9SAndroid Build Coastguard Worker logger.info("hi") 253*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(records), 1) 254*da0073e9SAndroid Build Coastguard Worker 255*da0073e9SAndroid Build Coastguard Worker # check logging to a random log that is not a child log of a registered 256*da0073e9SAndroid Build Coastguard Worker # logger registers it and sets handlers properly 257*da0073e9SAndroid Build Coastguard Worker @make_logging_test(modules={"torch.utils": logging.INFO}) 258*da0073e9SAndroid Build Coastguard Worker def test_open_registration_python_api(self, records): 259*da0073e9SAndroid Build Coastguard Worker logger = logging.getLogger("torch.utils") 260*da0073e9SAndroid Build Coastguard Worker logger.info("hi") 261*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(records), 1) 262*da0073e9SAndroid Build Coastguard Worker 263*da0073e9SAndroid Build Coastguard Worker @make_logging_test(all=logging.DEBUG, dynamo=logging.INFO) 264*da0073e9SAndroid Build Coastguard Worker def test_all(self, _): 265*da0073e9SAndroid Build Coastguard Worker registry = torch._logging._internal.log_registry 266*da0073e9SAndroid Build Coastguard Worker 267*da0073e9SAndroid Build Coastguard Worker dynamo_qnames = 
registry.log_alias_to_log_qnames["dynamo"] 268*da0073e9SAndroid Build Coastguard Worker for logger_qname in torch._logging._internal.log_registry.get_log_qnames(): 269*da0073e9SAndroid Build Coastguard Worker logger = logging.getLogger(logger_qname) 270*da0073e9SAndroid Build Coastguard Worker 271*da0073e9SAndroid Build Coastguard Worker # if logger_qname is a.b.c and dynamo_qnames contains a.b, it still matches dynamo's INFO setting 272*da0073e9SAndroid Build Coastguard Worker if any(logger_qname.find(d) == 0 for d in dynamo_qnames): 273*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 274*da0073e9SAndroid Build Coastguard Worker logger.getEffectiveLevel(), 275*da0073e9SAndroid Build Coastguard Worker logging.INFO, 276*da0073e9SAndroid Build Coastguard Worker msg=f"expected {logger_qname} is INFO, got {logging.getLevelName(logger.getEffectiveLevel())}", 277*da0073e9SAndroid Build Coastguard Worker ) 278*da0073e9SAndroid Build Coastguard Worker else: 279*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 280*da0073e9SAndroid Build Coastguard Worker logger.getEffectiveLevel(), 281*da0073e9SAndroid Build Coastguard Worker logging.DEBUG, 282*da0073e9SAndroid Build Coastguard Worker msg=f"expected {logger_qname} is DEBUG, got {logging.getLevelName(logger.getEffectiveLevel())}", 283*da0073e9SAndroid Build Coastguard Worker ) 284*da0073e9SAndroid Build Coastguard Worker 285*da0073e9SAndroid Build Coastguard Worker @make_logging_test(graph_breaks=True) 286*da0073e9SAndroid Build Coastguard Worker def test_graph_breaks(self, records): 287*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize("inductor") 288*da0073e9SAndroid Build Coastguard Worker def fn(x): 289*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 290*da0073e9SAndroid Build Coastguard Worker return x + 1 291*da0073e9SAndroid Build Coastguard Worker 292*da0073e9SAndroid Build Coastguard Worker fn(torch.ones(1)) 293*da0073e9SAndroid Build Coastguard Worker 
294*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(records), 1) 295*da0073e9SAndroid Build Coastguard Worker 296*da0073e9SAndroid Build Coastguard Worker @make_settings_test("torch._dynamo.utils") 297*da0073e9SAndroid Build Coastguard Worker def test_dump_compile_times(self, records): 298*da0073e9SAndroid Build Coastguard Worker fn_opt = torch._dynamo.optimize("inductor")(example_fn) 299*da0073e9SAndroid Build Coastguard Worker fn_opt(torch.ones(1000, 1000)) 300*da0073e9SAndroid Build Coastguard Worker # This function runs during exit via atexit.register. 301*da0073e9SAndroid Build Coastguard Worker # We're not actually going to run atexit._run_exit_funcs() here, 302*da0073e9SAndroid Build Coastguard Worker # because it'll destroy state necessary for other tests. 303*da0073e9SAndroid Build Coastguard Worker torch._dynamo.utils.dump_compile_times() 304*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 305*da0073e9SAndroid Build Coastguard Worker len( 306*da0073e9SAndroid Build Coastguard Worker [r for r in records if "TorchDynamo compilation metrics" in str(r.msg)] 307*da0073e9SAndroid Build Coastguard Worker ), 308*da0073e9SAndroid Build Coastguard Worker 1, 309*da0073e9SAndroid Build Coastguard Worker ) 310*da0073e9SAndroid Build Coastguard Worker 311*da0073e9SAndroid Build Coastguard Worker @make_logging_test(dynamo=logging.INFO) 312*da0073e9SAndroid Build Coastguard Worker def test_custom_format_exc(self, records): 313*da0073e9SAndroid Build Coastguard Worker dynamo_log = logging.getLogger(torch._dynamo.__name__) 314*da0073e9SAndroid Build Coastguard Worker try: 315*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("foo") 316*da0073e9SAndroid Build Coastguard Worker except RuntimeError: 317*da0073e9SAndroid Build Coastguard Worker dynamo_log.exception("test dynamo") 318*da0073e9SAndroid Build Coastguard Worker dynamo_log.info("with exc", exc_info=True) 319*da0073e9SAndroid Build Coastguard Worker dynamo_log.info("with stack", 
stack_info=True) 320*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(records), 3) 321*da0073e9SAndroid Build Coastguard Worker # unfortunately there's no easy way to test the final formatted log other than 322*da0073e9SAndroid Build Coastguard Worker # to ask the dynamo logger's handler to format it. 323*da0073e9SAndroid Build Coastguard Worker for handler in dynamo_log.handlers: 324*da0073e9SAndroid Build Coastguard Worker if torch._logging._internal._is_torch_handler(handler): 325*da0073e9SAndroid Build Coastguard Worker break 326*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(handler) 327*da0073e9SAndroid Build Coastguard Worker self.assertIn("Traceback", handler.format(records[0])) 328*da0073e9SAndroid Build Coastguard Worker self.assertIn("Traceback", handler.format(records[1])) 329*da0073e9SAndroid Build Coastguard Worker self.assertIn("Stack", handler.format(records[2])) 330*da0073e9SAndroid Build Coastguard Worker 331*da0073e9SAndroid Build Coastguard Worker @make_logging_test(dynamo=logging.INFO) 332*da0073e9SAndroid Build Coastguard Worker def test_custom_format(self, records): 333*da0073e9SAndroid Build Coastguard Worker dynamo_log = logging.getLogger(torch._dynamo.__name__) 334*da0073e9SAndroid Build Coastguard Worker test_log = torch._logging.getArtifactLogger( 335*da0073e9SAndroid Build Coastguard Worker torch._dynamo.__name__, "custom_format_test_artifact" 336*da0073e9SAndroid Build Coastguard Worker ) 337*da0073e9SAndroid Build Coastguard Worker dynamo_log.info("test dynamo") 338*da0073e9SAndroid Build Coastguard Worker test_log.info("custom format") 339*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(records), 2) 340*da0073e9SAndroid Build Coastguard Worker # unfortunately there's no easy way to test the final formatted log other than 341*da0073e9SAndroid Build Coastguard Worker # to ask the dynamo logger's handler to format it. 
342*da0073e9SAndroid Build Coastguard Worker for handler in dynamo_log.handlers: 343*da0073e9SAndroid Build Coastguard Worker if torch._logging._internal._is_torch_handler(handler): 344*da0073e9SAndroid Build Coastguard Worker break 345*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(handler) 346*da0073e9SAndroid Build Coastguard Worker self.assertIn("I", handler.format(records[0])) 347*da0073e9SAndroid Build Coastguard Worker self.assertEqual("custom format", handler.format(records[1])) 348*da0073e9SAndroid Build Coastguard Worker 349*da0073e9SAndroid Build Coastguard Worker @make_logging_test(dynamo=logging.INFO) 350*da0073e9SAndroid Build Coastguard Worker def test_multiline_format(self, records): 351*da0073e9SAndroid Build Coastguard Worker dynamo_log = logging.getLogger(torch._dynamo.__name__) 352*da0073e9SAndroid Build Coastguard Worker dynamo_log.info("test\ndynamo") 353*da0073e9SAndroid Build Coastguard Worker dynamo_log.info("%s", "test\ndynamo") 354*da0073e9SAndroid Build Coastguard Worker dynamo_log.info("test\n%s", "test\ndynamo") 355*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(records), 3) 356*da0073e9SAndroid Build Coastguard Worker # unfortunately there's no easy way to test the final formatted log other than 357*da0073e9SAndroid Build Coastguard Worker # to ask the dynamo logger's handler to format it. 
358*da0073e9SAndroid Build Coastguard Worker for handler in dynamo_log.handlers: 359*da0073e9SAndroid Build Coastguard Worker if torch._logging._internal._is_torch_handler(handler): 360*da0073e9SAndroid Build Coastguard Worker break 361*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(handler) 362*da0073e9SAndroid Build Coastguard Worker for record in records: 363*da0073e9SAndroid Build Coastguard Worker r = handler.format(record) 364*da0073e9SAndroid Build Coastguard Worker for l in r.splitlines(): 365*da0073e9SAndroid Build Coastguard Worker self.assertIn("I", l) 366*da0073e9SAndroid Build Coastguard Worker 367*da0073e9SAndroid Build Coastguard Worker test_trace_source_simple = within_range_record_test(1, 100, trace_source=True) 368*da0073e9SAndroid Build Coastguard Worker 369*da0073e9SAndroid Build Coastguard Worker @make_logging_test(trace_source=True) 370*da0073e9SAndroid Build Coastguard Worker def test_trace_source_if_stmt(self, records): 371*da0073e9SAndroid Build Coastguard Worker def fn(x): 372*da0073e9SAndroid Build Coastguard Worker if x.sum() > 0: 373*da0073e9SAndroid Build Coastguard Worker return x * 2 374*da0073e9SAndroid Build Coastguard Worker return x * 3 375*da0073e9SAndroid Build Coastguard Worker 376*da0073e9SAndroid Build Coastguard Worker fn_opt = torch._dynamo.optimize("eager")(fn) 377*da0073e9SAndroid Build Coastguard Worker fn_opt(torch.ones(3, 3)) 378*da0073e9SAndroid Build Coastguard Worker 379*da0073e9SAndroid Build Coastguard Worker found_x2 = False 380*da0073e9SAndroid Build Coastguard Worker found_x3 = False 381*da0073e9SAndroid Build Coastguard Worker for record in records: 382*da0073e9SAndroid Build Coastguard Worker msg = record.getMessage() 383*da0073e9SAndroid Build Coastguard Worker if "return x * 2" in msg: 384*da0073e9SAndroid Build Coastguard Worker found_x2 = True 385*da0073e9SAndroid Build Coastguard Worker if "return x * 3" in msg: 386*da0073e9SAndroid Build Coastguard Worker found_x3 = True 
387*da0073e9SAndroid Build Coastguard Worker 388*da0073e9SAndroid Build Coastguard Worker self.assertTrue(found_x2) 389*da0073e9SAndroid Build Coastguard Worker self.assertFalse(found_x3) 390*da0073e9SAndroid Build Coastguard Worker 391*da0073e9SAndroid Build Coastguard Worker @make_logging_test(trace_source=True) 392*da0073e9SAndroid Build Coastguard Worker def test_trace_source_nested(self, records): 393*da0073e9SAndroid Build Coastguard Worker def fn1(x): 394*da0073e9SAndroid Build Coastguard Worker x = fn2(x) 395*da0073e9SAndroid Build Coastguard Worker return x * 2 396*da0073e9SAndroid Build Coastguard Worker 397*da0073e9SAndroid Build Coastguard Worker def fn2(x): 398*da0073e9SAndroid Build Coastguard Worker x = fn3(x) 399*da0073e9SAndroid Build Coastguard Worker return x * 3 400*da0073e9SAndroid Build Coastguard Worker 401*da0073e9SAndroid Build Coastguard Worker def fn3(x): 402*da0073e9SAndroid Build Coastguard Worker return x * 4 403*da0073e9SAndroid Build Coastguard Worker 404*da0073e9SAndroid Build Coastguard Worker fn_opt = torch._dynamo.optimize("eager")(fn1) 405*da0073e9SAndroid Build Coastguard Worker fn_opt(torch.ones(3, 3)) 406*da0073e9SAndroid Build Coastguard Worker 407*da0073e9SAndroid Build Coastguard Worker found_x2 = False 408*da0073e9SAndroid Build Coastguard Worker found_x3 = False 409*da0073e9SAndroid Build Coastguard Worker found_x4 = False 410*da0073e9SAndroid Build Coastguard Worker for record in records: 411*da0073e9SAndroid Build Coastguard Worker msg = record.getMessage() 412*da0073e9SAndroid Build Coastguard Worker if "return x * 2" in msg: 413*da0073e9SAndroid Build Coastguard Worker found_x2 = True 414*da0073e9SAndroid Build Coastguard Worker self.assertNotIn("inline depth", msg) 415*da0073e9SAndroid Build Coastguard Worker elif "return x * 3" in msg: 416*da0073e9SAndroid Build Coastguard Worker found_x3 = True 417*da0073e9SAndroid Build Coastguard Worker self.assertIn("inline depth: 1", msg) 418*da0073e9SAndroid Build Coastguard 
Worker elif "return x * 4" in msg: 419*da0073e9SAndroid Build Coastguard Worker found_x4 = True 420*da0073e9SAndroid Build Coastguard Worker self.assertIn("inline depth: 2", msg) 421*da0073e9SAndroid Build Coastguard Worker self.assertTrue(found_x2) 422*da0073e9SAndroid Build Coastguard Worker self.assertTrue(found_x3) 423*da0073e9SAndroid Build Coastguard Worker self.assertTrue(found_x4) 424*da0073e9SAndroid Build Coastguard Worker 425*da0073e9SAndroid Build Coastguard Worker @make_logging_test(trace_source=True) 426*da0073e9SAndroid Build Coastguard Worker def test_trace_source_cond(self, records): 427*da0073e9SAndroid Build Coastguard Worker from functorch.experimental.control_flow import cond 428*da0073e9SAndroid Build Coastguard Worker 429*da0073e9SAndroid Build Coastguard Worker def true_fn(x): 430*da0073e9SAndroid Build Coastguard Worker return x * 2 431*da0073e9SAndroid Build Coastguard Worker 432*da0073e9SAndroid Build Coastguard Worker def false_fn(x): 433*da0073e9SAndroid Build Coastguard Worker return x * 3 434*da0073e9SAndroid Build Coastguard Worker 435*da0073e9SAndroid Build Coastguard Worker def inner(pred, x): 436*da0073e9SAndroid Build Coastguard Worker return cond(pred, true_fn, false_fn, [x]) 437*da0073e9SAndroid Build Coastguard Worker 438*da0073e9SAndroid Build Coastguard Worker def outer(pred, x): 439*da0073e9SAndroid Build Coastguard Worker return inner(pred, x) 440*da0073e9SAndroid Build Coastguard Worker 441*da0073e9SAndroid Build Coastguard Worker fn_opt = torch._dynamo.optimize("eager")(outer) 442*da0073e9SAndroid Build Coastguard Worker fn_opt(torch.tensor(True), torch.ones(3, 3)) 443*da0073e9SAndroid Build Coastguard Worker 444*da0073e9SAndroid Build Coastguard Worker found_x2 = False 445*da0073e9SAndroid Build Coastguard Worker found_x3 = False 446*da0073e9SAndroid Build Coastguard Worker for record in records: 447*da0073e9SAndroid Build Coastguard Worker msg = record.getMessage() 448*da0073e9SAndroid Build Coastguard Worker if 
"return x * 2" in msg: 449*da0073e9SAndroid Build Coastguard Worker found_x2 = True 450*da0073e9SAndroid Build Coastguard Worker self.assertIn("inline depth: 3", msg) 451*da0073e9SAndroid Build Coastguard Worker if "return x * 3" in msg: 452*da0073e9SAndroid Build Coastguard Worker found_x3 = True 453*da0073e9SAndroid Build Coastguard Worker self.assertIn("inline depth: 3", msg) 454*da0073e9SAndroid Build Coastguard Worker 455*da0073e9SAndroid Build Coastguard Worker self.assertTrue(found_x2) 456*da0073e9SAndroid Build Coastguard Worker self.assertTrue(found_x3) 457*da0073e9SAndroid Build Coastguard Worker 458*da0073e9SAndroid Build Coastguard Worker @make_logging_test(trace_source=True) 459*da0073e9SAndroid Build Coastguard Worker def test_trace_source_funcname(self, records): 460*da0073e9SAndroid Build Coastguard Worker # NOTE: list comprehensions are inlined in 3.12, so test with tuples 461*da0073e9SAndroid Build Coastguard Worker def fn1(): 462*da0073e9SAndroid Build Coastguard Worker def fn2(): 463*da0073e9SAndroid Build Coastguard Worker if True: 464*da0073e9SAndroid Build Coastguard Worker return tuple(torch.ones(3, 3) for _ in range(5)) 465*da0073e9SAndroid Build Coastguard Worker return None 466*da0073e9SAndroid Build Coastguard Worker 467*da0073e9SAndroid Build Coastguard Worker return fn2() 468*da0073e9SAndroid Build Coastguard Worker 469*da0073e9SAndroid Build Coastguard Worker fn_opt = torch._dynamo.optimize("eager")(fn1) 470*da0073e9SAndroid Build Coastguard Worker fn_opt() 471*da0073e9SAndroid Build Coastguard Worker 472*da0073e9SAndroid Build Coastguard Worker found_funcname = False 473*da0073e9SAndroid Build Coastguard Worker for record in records: 474*da0073e9SAndroid Build Coastguard Worker msg = record.getMessage() 475*da0073e9SAndroid Build Coastguard Worker if "<genexpr>" in msg and "fn1.fn2" in msg: 476*da0073e9SAndroid Build Coastguard Worker found_funcname = True 477*da0073e9SAndroid Build Coastguard Worker 478*da0073e9SAndroid Build 
Coastguard Worker self.assertTrue(found_funcname) 479*da0073e9SAndroid Build Coastguard Worker 480*da0073e9SAndroid Build Coastguard Worker def test_invalid_artifact_flag(self): 481*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 482*da0073e9SAndroid Build Coastguard Worker torch._logging.set_logs(aot_graphs=5) 483*da0073e9SAndroid Build Coastguard Worker 484*da0073e9SAndroid Build Coastguard Worker @requires_distributed() 485*da0073e9SAndroid Build Coastguard Worker def test_distributed_rank_logging(self): 486*da0073e9SAndroid Build Coastguard Worker env = dict(os.environ) 487*da0073e9SAndroid Build Coastguard Worker env["TORCH_LOGS"] = "dynamo" 488*da0073e9SAndroid Build Coastguard Worker stdout, stderr = self.run_process_no_exception( 489*da0073e9SAndroid Build Coastguard Worker """\ 490*da0073e9SAndroid Build Coastguard Workerimport torch.distributed as dist 491*da0073e9SAndroid Build Coastguard Workerimport logging 492*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.distributed.fake_pg import FakeStore 493*da0073e9SAndroid Build Coastguard Workerstore = FakeStore() 494*da0073e9SAndroid Build Coastguard Workerdist.init_process_group("fake", rank=0, world_size=2, store=store) 495*da0073e9SAndroid Build Coastguard Workerdynamo_log = logging.getLogger("torch._dynamo") 496*da0073e9SAndroid Build Coastguard Workerdynamo_log.info("woof") 497*da0073e9SAndroid Build Coastguard Workerprint("arf") 498*da0073e9SAndroid Build Coastguard Worker""", 499*da0073e9SAndroid Build Coastguard Worker env=env, 500*da0073e9SAndroid Build Coastguard Worker ) 501*da0073e9SAndroid Build Coastguard Worker self.assertIn("[rank0]:", stderr.decode("utf-8")) 502*da0073e9SAndroid Build Coastguard Worker 503*da0073e9SAndroid Build Coastguard Worker @skipIfNotPy311 504*da0073e9SAndroid Build Coastguard Worker @make_logging_test(trace_call=True) 505*da0073e9SAndroid Build Coastguard Worker def test_trace_call(self, records): 506*da0073e9SAndroid 
Build Coastguard Worker def fn(x, y): 507*da0073e9SAndroid Build Coastguard Worker return (x * 2) @ (y * 3) 508*da0073e9SAndroid Build Coastguard Worker 509*da0073e9SAndroid Build Coastguard Worker fn_opt = torch._dynamo.optimize("eager")(fn) 510*da0073e9SAndroid Build Coastguard Worker fn_opt(torch.randn(10, 20), torch.randn(20, 30)) 511*da0073e9SAndroid Build Coastguard Worker 512*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(records), 3) 513*da0073e9SAndroid Build Coastguard Worker # only get last 2 lines 514*da0073e9SAndroid Build Coastguard Worker messages = [ 515*da0073e9SAndroid Build Coastguard Worker "\n".join(record.getMessage().split("\n")[-2:]) for record in records 516*da0073e9SAndroid Build Coastguard Worker ] 517*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 518*da0073e9SAndroid Build Coastguard Worker messages[0], 519*da0073e9SAndroid Build Coastguard Worker """\ 520*da0073e9SAndroid Build Coastguard Worker return (x * 2) @ (y * 3) 521*da0073e9SAndroid Build Coastguard Worker ~~^~~""", 522*da0073e9SAndroid Build Coastguard Worker ) 523*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 524*da0073e9SAndroid Build Coastguard Worker messages[1], 525*da0073e9SAndroid Build Coastguard Worker """\ 526*da0073e9SAndroid Build Coastguard Worker return (x * 2) @ (y * 3) 527*da0073e9SAndroid Build Coastguard Worker ~~^~~""", 528*da0073e9SAndroid Build Coastguard Worker ) 529*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 530*da0073e9SAndroid Build Coastguard Worker messages[2], 531*da0073e9SAndroid Build Coastguard Worker """\ 532*da0073e9SAndroid Build Coastguard Worker return (x * 2) @ (y * 3) 533*da0073e9SAndroid Build Coastguard Worker ~~~~~~~~^~~~~~~~~""", 534*da0073e9SAndroid Build Coastguard Worker ) 535*da0073e9SAndroid Build Coastguard Worker 536*da0073e9SAndroid Build Coastguard Worker @skipIfNotPy311 537*da0073e9SAndroid Build Coastguard Worker 
@make_logging_test(trace_call=True) 538*da0073e9SAndroid Build Coastguard Worker def test_trace_call_inline_call(self, records): 539*da0073e9SAndroid Build Coastguard Worker def g(x): 540*da0073e9SAndroid Build Coastguard Worker return x * 2 541*da0073e9SAndroid Build Coastguard Worker 542*da0073e9SAndroid Build Coastguard Worker def f(x): 543*da0073e9SAndroid Build Coastguard Worker return g(g(x)) 544*da0073e9SAndroid Build Coastguard Worker 545*da0073e9SAndroid Build Coastguard Worker fn_opt = torch._dynamo.optimize("eager")(f) 546*da0073e9SAndroid Build Coastguard Worker fn_opt(torch.randn(3, 3)) 547*da0073e9SAndroid Build Coastguard Worker 548*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(records), 4) 549*da0073e9SAndroid Build Coastguard Worker messages = [ 550*da0073e9SAndroid Build Coastguard Worker "\n".join(record.getMessage().split("\n")[-2:]) for record in records 551*da0073e9SAndroid Build Coastguard Worker ] 552*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 553*da0073e9SAndroid Build Coastguard Worker messages[0], 554*da0073e9SAndroid Build Coastguard Worker """\ 555*da0073e9SAndroid Build Coastguard Worker return g(g(x)) 556*da0073e9SAndroid Build Coastguard Worker ~^^^""", 557*da0073e9SAndroid Build Coastguard Worker ) 558*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 559*da0073e9SAndroid Build Coastguard Worker messages[1], 560*da0073e9SAndroid Build Coastguard Worker """\ 561*da0073e9SAndroid Build Coastguard Worker return x * 2 562*da0073e9SAndroid Build Coastguard Worker ~~^~~""", 563*da0073e9SAndroid Build Coastguard Worker ) 564*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 565*da0073e9SAndroid Build Coastguard Worker messages[2], 566*da0073e9SAndroid Build Coastguard Worker """\ 567*da0073e9SAndroid Build Coastguard Worker return g(g(x)) 568*da0073e9SAndroid Build Coastguard Worker ~^^^^^^""", 569*da0073e9SAndroid Build Coastguard Worker ) 570*da0073e9SAndroid 
Build Coastguard Worker self.assertExpectedInline( 571*da0073e9SAndroid Build Coastguard Worker messages[3], 572*da0073e9SAndroid Build Coastguard Worker """\ 573*da0073e9SAndroid Build Coastguard Worker return x * 2 574*da0073e9SAndroid Build Coastguard Worker ~~^~~""", 575*da0073e9SAndroid Build Coastguard Worker ) 576*da0073e9SAndroid Build Coastguard Worker 577*da0073e9SAndroid Build Coastguard Worker @skipIfNotPy311 578*da0073e9SAndroid Build Coastguard Worker @make_logging_test(trace_call=True) 579*da0073e9SAndroid Build Coastguard Worker def test_trace_call_graph_break(self, records): 580*da0073e9SAndroid Build Coastguard Worker def fn(x): 581*da0073e9SAndroid Build Coastguard Worker x = x * 2 582*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 583*da0073e9SAndroid Build Coastguard Worker return x * 3 584*da0073e9SAndroid Build Coastguard Worker 585*da0073e9SAndroid Build Coastguard Worker fn_opt = torch._dynamo.optimize("eager")(fn) 586*da0073e9SAndroid Build Coastguard Worker fn_opt(torch.randn(3, 3)) 587*da0073e9SAndroid Build Coastguard Worker 588*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(records), 3) 589*da0073e9SAndroid Build Coastguard Worker messages = [ 590*da0073e9SAndroid Build Coastguard Worker "\n".join(record.getMessage().split("\n")[-2:]) for record in records 591*da0073e9SAndroid Build Coastguard Worker ] 592*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 593*da0073e9SAndroid Build Coastguard Worker messages[0], 594*da0073e9SAndroid Build Coastguard Worker """\ 595*da0073e9SAndroid Build Coastguard Worker x = x * 2 596*da0073e9SAndroid Build Coastguard Worker ~~^~~""", 597*da0073e9SAndroid Build Coastguard Worker ) 598*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 599*da0073e9SAndroid Build Coastguard Worker messages[-1], 600*da0073e9SAndroid Build Coastguard Worker """\ 601*da0073e9SAndroid Build Coastguard Worker return x * 3 602*da0073e9SAndroid Build 
Coastguard Worker ~~^~~""", 603*da0073e9SAndroid Build Coastguard Worker ) 604*da0073e9SAndroid Build Coastguard Worker 605*da0073e9SAndroid Build Coastguard Worker @make_logging_test(guards=True, recompiles=True) 606*da0073e9SAndroid Build Coastguard Worker def test_guards_recompiles(self, records): 607*da0073e9SAndroid Build Coastguard Worker def fn(x, ys, zs): 608*da0073e9SAndroid Build Coastguard Worker return inner(x, ys, zs) 609*da0073e9SAndroid Build Coastguard Worker 610*da0073e9SAndroid Build Coastguard Worker def inner(x, ys, zs): 611*da0073e9SAndroid Build Coastguard Worker for y, z in zip(ys, zs): 612*da0073e9SAndroid Build Coastguard Worker x += y * z 613*da0073e9SAndroid Build Coastguard Worker return x 614*da0073e9SAndroid Build Coastguard Worker 615*da0073e9SAndroid Build Coastguard Worker ys = [1.0, 2.0] 616*da0073e9SAndroid Build Coastguard Worker zs = [3.0] 617*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([1.0]) 618*da0073e9SAndroid Build Coastguard Worker 619*da0073e9SAndroid Build Coastguard Worker fn_opt = torch._dynamo.optimize("eager")(fn) 620*da0073e9SAndroid Build Coastguard Worker fn_opt(x, ys, zs) 621*da0073e9SAndroid Build Coastguard Worker fn_opt(x, ys[:1], zs) 622*da0073e9SAndroid Build Coastguard Worker 623*da0073e9SAndroid Build Coastguard Worker record_str = "\n".join(r.getMessage() for r in records) 624*da0073e9SAndroid Build Coastguard Worker 625*da0073e9SAndroid Build Coastguard Worker self.assertIn( 626*da0073e9SAndroid Build Coastguard Worker """L['zs'][0] == 3.0""", 627*da0073e9SAndroid Build Coastguard Worker record_str, 628*da0073e9SAndroid Build Coastguard Worker ) 629*da0073e9SAndroid Build Coastguard Worker self.assertIn( 630*da0073e9SAndroid Build Coastguard Worker "len(L['ys']) == 2", 631*da0073e9SAndroid Build Coastguard Worker record_str, 632*da0073e9SAndroid Build Coastguard Worker ) 633*da0073e9SAndroid Build Coastguard Worker 634*da0073e9SAndroid Build Coastguard Worker 
@make_logging_test(cudagraph_static_inputs=True) 635*da0073e9SAndroid Build Coastguard Worker def test_cudagraph_static_inputs(self, records): 636*da0073e9SAndroid Build Coastguard Worker @torch.compile(mode="reduce-overhead") 637*da0073e9SAndroid Build Coastguard Worker def fn(x): 638*da0073e9SAndroid Build Coastguard Worker return x + 1 639*da0073e9SAndroid Build Coastguard Worker 640*da0073e9SAndroid Build Coastguard Worker x = torch.ones(2, 2) 641*da0073e9SAndroid Build Coastguard Worker torch._dynamo.mark_static_address(x) 642*da0073e9SAndroid Build Coastguard Worker fn(x) 643*da0073e9SAndroid Build Coastguard Worker self.assertGreater(len(records), 0) 644*da0073e9SAndroid Build Coastguard Worker self.assertLess(len(records), 4) 645*da0073e9SAndroid Build Coastguard Worker 646*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("too slow") 647*da0073e9SAndroid Build Coastguard Worker @make_logging_test(**torch._logging.DEFAULT_LOGGING) 648*da0073e9SAndroid Build Coastguard Worker def test_default_logging(self, records): 649*da0073e9SAndroid Build Coastguard Worker def fn(a): 650*da0073e9SAndroid Build Coastguard Worker if a.sum() < 0: 651*da0073e9SAndroid Build Coastguard Worker a = torch.sin(a) 652*da0073e9SAndroid Build Coastguard Worker else: 653*da0073e9SAndroid Build Coastguard Worker a = torch.cos(a) 654*da0073e9SAndroid Build Coastguard Worker print("hello") 655*da0073e9SAndroid Build Coastguard Worker return a + 1 656*da0073e9SAndroid Build Coastguard Worker 657*da0073e9SAndroid Build Coastguard Worker fn_opt = torch._dynamo.optimize("eager")(fn) 658*da0073e9SAndroid Build Coastguard Worker fn_opt(torch.ones(10, 10)) 659*da0073e9SAndroid Build Coastguard Worker fn_opt(-torch.ones(10, 5)) 660*da0073e9SAndroid Build Coastguard Worker 661*da0073e9SAndroid Build Coastguard Worker self.assertGreater(len([r for r in records if ".__graph_breaks" in r.name]), 0) 662*da0073e9SAndroid Build Coastguard Worker self.assertGreater(len([r for r in records if 
".__recompiles" in r.name]), 0) 663*da0073e9SAndroid Build Coastguard Worker self.assertGreater(len([r for r in records if ".symbolic_shapes" in r.name]), 0) 664*da0073e9SAndroid Build Coastguard Worker self.assertGreater(len([r for r in records if ".__guards" in r.name]), 0) 665*da0073e9SAndroid Build Coastguard Worker self.assertGreater( 666*da0073e9SAndroid Build Coastguard Worker len([r for r in records if "return a + 1" in r.getMessage()]), 0 667*da0073e9SAndroid Build Coastguard Worker ) 668*da0073e9SAndroid Build Coastguard Worker 669*da0073e9SAndroid Build Coastguard Worker def test_logs_out(self): 670*da0073e9SAndroid Build Coastguard Worker import tempfile 671*da0073e9SAndroid Build Coastguard Worker 672*da0073e9SAndroid Build Coastguard Worker with tempfile.NamedTemporaryFile(delete=False) as tmp: 673*da0073e9SAndroid Build Coastguard Worker file_path = _as_posix_path(tmp.name) 674*da0073e9SAndroid Build Coastguard Worker """ 675*da0073e9SAndroid Build Coastguard Worker NamedTemporaryFile will include a file open operation. 676*da0073e9SAndroid Build Coastguard Worker On Windowsm the file is opened by NamedTemporaryFile, the 677*da0073e9SAndroid Build Coastguard Worker following run_process_no_exception can't access a opened file. 
678*da0073e9SAndroid Build Coastguard Worker And then, raise a PermissionError: [Errno 13] Permission denied: [file_path] 679*da0073e9SAndroid Build Coastguard Worker """ 680*da0073e9SAndroid Build Coastguard Worker tmp.close() 681*da0073e9SAndroid Build Coastguard Worker env = dict(os.environ) 682*da0073e9SAndroid Build Coastguard Worker env["TORCH_LOGS"] = "dynamo" 683*da0073e9SAndroid Build Coastguard Worker env["TORCH_LOGS_OUT"] = file_path 684*da0073e9SAndroid Build Coastguard Worker stdout, stderr = self.run_process_no_exception( 685*da0073e9SAndroid Build Coastguard Worker """\ 686*da0073e9SAndroid Build Coastguard Workerimport torch 687*da0073e9SAndroid Build Coastguard Worker@torch.compile(backend="eager") 688*da0073e9SAndroid Build Coastguard Workerdef fn(a): 689*da0073e9SAndroid Build Coastguard Worker return a.sum() 690*da0073e9SAndroid Build Coastguard Worker 691*da0073e9SAndroid Build Coastguard Workerfn(torch.randn(5)) 692*da0073e9SAndroid Build Coastguard Worker """, 693*da0073e9SAndroid Build Coastguard Worker env=env, 694*da0073e9SAndroid Build Coastguard Worker ) 695*da0073e9SAndroid Build Coastguard Worker with open( 696*da0073e9SAndroid Build Coastguard Worker file_path, encoding="utf-8" 697*da0073e9SAndroid Build Coastguard Worker ) as fd: # encoding file to UTF-8 for Windows. 698*da0073e9SAndroid Build Coastguard Worker lines = fd.read() 699*da0073e9SAndroid Build Coastguard Worker fd.close() 700*da0073e9SAndroid Build Coastguard Worker os.remove( 701*da0073e9SAndroid Build Coastguard Worker file_path 702*da0073e9SAndroid Build Coastguard Worker ) # Delete temp file manually, due to setup NamedTemporaryFile as delete=False. 703*da0073e9SAndroid Build Coastguard Worker self.assertEqual( # process wrap difference: /r/n on Windows, /n on posix. 
704*da0073e9SAndroid Build Coastguard Worker empty_line_normalizer(lines), 705*da0073e9SAndroid Build Coastguard Worker empty_line_normalizer(stderr.decode("utf-8")), 706*da0073e9SAndroid Build Coastguard Worker ) 707*da0073e9SAndroid Build Coastguard Worker 708*da0073e9SAndroid Build Coastguard Worker 709*da0073e9SAndroid Build Coastguard Worker# single record tests 710*da0073e9SAndroid Build Coastguard Workerexclusions = { 711*da0073e9SAndroid Build Coastguard Worker "bytecode", 712*da0073e9SAndroid Build Coastguard Worker "cudagraphs", 713*da0073e9SAndroid Build Coastguard Worker "output_code", 714*da0073e9SAndroid Build Coastguard Worker "schedule", 715*da0073e9SAndroid Build Coastguard Worker "fusion", 716*da0073e9SAndroid Build Coastguard Worker "overlap", 717*da0073e9SAndroid Build Coastguard Worker "aot_graphs", 718*da0073e9SAndroid Build Coastguard Worker "aot_graphs_effects", 719*da0073e9SAndroid Build Coastguard Worker "post_grad_graphs", 720*da0073e9SAndroid Build Coastguard Worker "compiled_autograd", 721*da0073e9SAndroid Build Coastguard Worker "compiled_autograd_verbose", 722*da0073e9SAndroid Build Coastguard Worker "recompiles", 723*da0073e9SAndroid Build Coastguard Worker "recompiles_verbose", 724*da0073e9SAndroid Build Coastguard Worker "graph_breaks", 725*da0073e9SAndroid Build Coastguard Worker "graph", 726*da0073e9SAndroid Build Coastguard Worker "graph_code", 727*da0073e9SAndroid Build Coastguard Worker "graph_sizes", 728*da0073e9SAndroid Build Coastguard Worker "ddp_graphs", 729*da0073e9SAndroid Build Coastguard Worker "perf_hints", 730*da0073e9SAndroid Build Coastguard Worker "not_implemented", 731*da0073e9SAndroid Build Coastguard Worker "trace_source", 732*da0073e9SAndroid Build Coastguard Worker "trace_call", 733*da0073e9SAndroid Build Coastguard Worker "trace_bytecode", 734*da0073e9SAndroid Build Coastguard Worker "custom_format_test_artifact", 735*da0073e9SAndroid Build Coastguard Worker "onnx", 736*da0073e9SAndroid Build Coastguard 
Worker "onnx_diagnostics", 737*da0073e9SAndroid Build Coastguard Worker "guards", 738*da0073e9SAndroid Build Coastguard Worker "verbose_guards", 739*da0073e9SAndroid Build Coastguard Worker "sym_node", 740*da0073e9SAndroid Build Coastguard Worker "export", 741*da0073e9SAndroid Build Coastguard Worker "trace_shape_events", 742*da0073e9SAndroid Build Coastguard Worker "cudagraph_static_inputs", 743*da0073e9SAndroid Build Coastguard Worker "benchmarking", 744*da0073e9SAndroid Build Coastguard Worker "loop_ordering", 745*da0073e9SAndroid Build Coastguard Worker} 746*da0073e9SAndroid Build Coastguard Workerfor name in torch._logging._internal.log_registry.artifact_names: 747*da0073e9SAndroid Build Coastguard Worker if name not in exclusions: 748*da0073e9SAndroid Build Coastguard Worker setattr(LoggingTests, f"test_{name}", single_record_test(**{name: True})) 749*da0073e9SAndroid Build Coastguard Worker 750*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 751*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.test_case import run_tests 752*da0073e9SAndroid Build Coastguard Worker 753*da0073e9SAndroid Build Coastguard Worker run_tests() 754