xref: /aosp_15_r20/external/pytorch/test/dynamo/test_logging.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"]
2*da0073e9SAndroid Build Coastguard Workerimport contextlib
3*da0073e9SAndroid Build Coastguard Workerimport functools
4*da0073e9SAndroid Build Coastguard Workerimport logging
5*da0073e9SAndroid Build Coastguard Workerimport os
6*da0073e9SAndroid Build Coastguard Workerimport unittest.mock
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Workerimport torch
9*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case
10*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.testing
11*da0073e9SAndroid Build Coastguard Workerimport torch.distributed as dist
12*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.testing import empty_line_normalizer, skipIfNotPy311
13*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.trace_rules import _as_posix_path
14*da0073e9SAndroid Build Coastguard Workerfrom torch.nn.parallel import DistributedDataParallel as DDP
15*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (
16*da0073e9SAndroid Build Coastguard Worker    find_free_port,
17*da0073e9SAndroid Build Coastguard Worker    munge_exc,
18*da0073e9SAndroid Build Coastguard Worker    skipIfTorchDynamo,
19*da0073e9SAndroid Build Coastguard Worker)
20*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.inductor_utils import HAS_CUDA
21*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.logging_utils import (
22*da0073e9SAndroid Build Coastguard Worker    LoggingTestCase,
23*da0073e9SAndroid Build Coastguard Worker    make_logging_test,
24*da0073e9SAndroid Build Coastguard Worker    make_settings_test,
25*da0073e9SAndroid Build Coastguard Worker)
26*da0073e9SAndroid Build Coastguard Worker
27*da0073e9SAndroid Build Coastguard Worker
# Skip decorator: the decorated test only runs when HAS_CUDA is truthy.
requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
# Skip-decorator factory: it must be called (``@requires_distributed()``) to
# obtain the decorator, which skips when torch.distributed is not available.
requires_distributed = functools.partial(
    unittest.skipIf, not dist.is_available(), "requires distributed"
)
32*da0073e9SAndroid Build Coastguard Worker
33*da0073e9SAndroid Build Coastguard Worker
def example_fn(a):
    # Fixture compiled by the record-count tests in this file: multiplies `a`
    # by a 1000x1000 tensor of ones, then adds another 1000x1000 tensor of ones.
    # NOTE(review): several tests assert on how many log records compiling this
    # function produces, so the statements are best left exactly as written.
    output = a.mul(torch.ones(1000, 1000))
    output = output.add(torch.ones(1000, 1000))
    return output
38*da0073e9SAndroid Build Coastguard Worker
39*da0073e9SAndroid Build Coastguard Worker
def dynamo_error_fn(a):
    # Fixture that fails on purpose: the (10, 10) tensor added below cannot be
    # broadcast against the (1000, 1000) product.
    # test_dynamo_error compares the logged error, including the source text of
    # the failing line, against an inline expectation, so keep the statements
    # below byte-for-byte unchanged.
    output = a.mul(torch.ones(1000, 1000))
    output = output.add(torch.ones(10, 10))
    return output
44*da0073e9SAndroid Build Coastguard Worker
45*da0073e9SAndroid Build Coastguard Worker
def inductor_error_fn(a):
    """Return ``torch.round(a)``; fixture compiled by test_inductor_error."""
    return torch.round(a)
49*da0073e9SAndroid Build Coastguard Worker
50*da0073e9SAndroid Build Coastguard Worker
def inductor_schedule_fn(a):
    """Add a 1000x1000 CUDA tensor of ones to ``a`` and return the sum."""
    cuda_ones = torch.ones(1000, 1000, device="cuda")
    return a.add(cuda_ones)
54*da0073e9SAndroid Build Coastguard Worker
55*da0073e9SAndroid Build Coastguard Worker
# Positional arguments passed to the compiled fixtures: a single 1000x1000
# tensor of ones created with requires_grad=True.
ARGS = (torch.ones(1000, 1000, requires_grad=True),)
57*da0073e9SAndroid Build Coastguard Worker
58*da0073e9SAndroid Build Coastguard Worker
def multi_record_test(num_records, **kwargs):
    """Build a logging test that expects exactly ``num_records`` records.

    ``kwargs`` are forwarded unchanged to ``make_logging_test``. The generated
    test compiles ``example_fn`` with the inductor backend, runs it on ``ARGS``
    and compares the number of captured records with ``num_records``.
    """

    def fn(self, records):
        compiled = torch._dynamo.optimize("inductor")(example_fn)
        compiled(*ARGS)
        self.assertEqual(len(records), num_records)

    return make_logging_test(**kwargs)(fn)
67*da0073e9SAndroid Build Coastguard Worker
68*da0073e9SAndroid Build Coastguard Worker
def within_range_record_test(num_records_lower, num_records_higher, **kwargs):
    """Build a logging test whose record count must fall in a closed range.

    ``kwargs`` are forwarded unchanged to ``make_logging_test``. The generated
    test compiles ``example_fn`` with the inductor backend, runs it on ``ARGS``
    and requires ``num_records_lower <= len(records) <= num_records_higher``.
    """

    def fn(self, records):
        compiled = torch._dynamo.optimize("inductor")(example_fn)
        compiled(*ARGS)
        captured = len(records)
        self.assertGreaterEqual(captured, num_records_lower)
        self.assertLessEqual(captured, num_records_higher)

    return make_logging_test(**kwargs)(fn)
78*da0073e9SAndroid Build Coastguard Worker
79*da0073e9SAndroid Build Coastguard Worker
def single_record_test(**kwargs):
    """Build a logging test that expects exactly one record."""
    return multi_record_test(num_records=1, **kwargs)
82*da0073e9SAndroid Build Coastguard Worker
83*da0073e9SAndroid Build Coastguard Worker
84*da0073e9SAndroid Build Coastguard Workerclass LoggingTests(LoggingTestCase):
    # Generated tests: each enables one artifact and pins the exact number of
    # records captured while compiling example_fn.
    test_bytecode = multi_record_test(2, bytecode=True)
    test_output_code = multi_record_test(2, output_code=True)
    test_aot_graphs = multi_record_test(3, aot_graphs=True)
88*da0073e9SAndroid Build Coastguard Worker
89*da0073e9SAndroid Build Coastguard Worker    @requires_cuda
90*da0073e9SAndroid Build Coastguard Worker    @make_logging_test(schedule=True)
91*da0073e9SAndroid Build Coastguard Worker    def test_schedule(self, records):
92*da0073e9SAndroid Build Coastguard Worker        fn_opt = torch._dynamo.optimize("inductor")(inductor_schedule_fn)
93*da0073e9SAndroid Build Coastguard Worker        fn_opt(torch.ones(1000, 1000, device="cuda"))
94*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(len(records), 0)
95*da0073e9SAndroid Build Coastguard Worker        self.assertLess(len(records), 5)
96*da0073e9SAndroid Build Coastguard Worker
97*da0073e9SAndroid Build Coastguard Worker    @requires_cuda
98*da0073e9SAndroid Build Coastguard Worker    @make_logging_test(fusion=True)
99*da0073e9SAndroid Build Coastguard Worker    def test_fusion(self, records):
100*da0073e9SAndroid Build Coastguard Worker        fn_opt = torch._dynamo.optimize("inductor")(inductor_schedule_fn)
101*da0073e9SAndroid Build Coastguard Worker        fn_opt(torch.ones(1000, 1000, device="cuda"))
102*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(len(records), 0)
103*da0073e9SAndroid Build Coastguard Worker        self.assertLess(len(records), 8)
104*da0073e9SAndroid Build Coastguard Worker
105*da0073e9SAndroid Build Coastguard Worker    @requires_cuda
106*da0073e9SAndroid Build Coastguard Worker    @make_logging_test(cudagraphs=True)
107*da0073e9SAndroid Build Coastguard Worker    def test_cudagraphs(self, records):
108*da0073e9SAndroid Build Coastguard Worker        fn_opt = torch.compile(mode="reduce-overhead")(inductor_schedule_fn)
109*da0073e9SAndroid Build Coastguard Worker        fn_opt(torch.ones(1000, 1000, device="cuda"))
110*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(len(records), 0)
111*da0073e9SAndroid Build Coastguard Worker        self.assertLess(len(records), 8)
112*da0073e9SAndroid Build Coastguard Worker
113*da0073e9SAndroid Build Coastguard Worker    @make_logging_test(recompiles=True)
114*da0073e9SAndroid Build Coastguard Worker    def test_recompiles(self, records):
115*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
116*da0073e9SAndroid Build Coastguard Worker            return torch.add(x, y)
117*da0073e9SAndroid Build Coastguard Worker
118*da0073e9SAndroid Build Coastguard Worker        fn_opt = torch._dynamo.optimize("inductor")(fn)
119*da0073e9SAndroid Build Coastguard Worker        fn_opt(torch.ones(1000, 1000), torch.ones(1000, 1000))
120*da0073e9SAndroid Build Coastguard Worker        fn_opt(torch.ones(1000, 1000), 1)
121*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(len(records), 0)
122*da0073e9SAndroid Build Coastguard Worker
    # Generated tests: the dynamo logger's record count is only required to
    # fall inside a range, at DEBUG and at INFO level respectively.
    test_dynamo_debug = within_range_record_test(30, 90, dynamo=logging.DEBUG)
    test_dynamo_info = within_range_record_test(2, 10, dynamo=logging.INFO)
125*da0073e9SAndroid Build Coastguard Worker
126*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("too slow")
127*da0073e9SAndroid Build Coastguard Worker    @make_logging_test(dynamo=logging.DEBUG)
128*da0073e9SAndroid Build Coastguard Worker    def test_dynamo_debug_default_off_artifacts(self, records):
129*da0073e9SAndroid Build Coastguard Worker        fn_opt = torch._dynamo.optimize("inductor")(example_fn)
130*da0073e9SAndroid Build Coastguard Worker        fn_opt(torch.ones(1000, 1000))
131*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len([r for r in records if ".__bytecode" in r.name]), 0)
132*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len([r for r in records if ".__output_code" in r.name]), 0)
133*da0073e9SAndroid Build Coastguard Worker
134*da0073e9SAndroid Build Coastguard Worker    @make_logging_test()
135*da0073e9SAndroid Build Coastguard Worker    def test_dynamo_error(self, records):
136*da0073e9SAndroid Build Coastguard Worker        try:
137*da0073e9SAndroid Build Coastguard Worker            fn_opt = torch._dynamo.optimize("inductor")(dynamo_error_fn)
138*da0073e9SAndroid Build Coastguard Worker            fn_opt(*ARGS)
139*da0073e9SAndroid Build Coastguard Worker        except Exception:
140*da0073e9SAndroid Build Coastguard Worker            pass
141*da0073e9SAndroid Build Coastguard Worker        record = self.getRecord(records, "WON'T CONVERT")
142*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
143*da0073e9SAndroid Build Coastguard Worker            munge_exc(record.getMessage()),
144*da0073e9SAndroid Build Coastguard Worker            """\
145*da0073e9SAndroid Build Coastguard WorkerWON'T CONVERT dynamo_error_fn test_logging.py line N
146*da0073e9SAndroid Build Coastguard Workerdue to:
147*da0073e9SAndroid Build Coastguard WorkerTraceback (most recent call last):
148*da0073e9SAndroid Build Coastguard Workertorch._dynamo.exc.TorchRuntimeError: Failed running call_method add(*(FakeTensor(..., size=(1000, 1000), grad_fn=<MulBackward0>), FakeTensor(..., size=(10, 10))), **{}):
149*da0073e9SAndroid Build Coastguard WorkerAttempting to broadcast a dimension of length 10 at -1! Mismatching argument at index 1 had torch.Size([10, 10]); but expected shape should be broadcastable to [1000, 1000]
150*da0073e9SAndroid Build Coastguard Worker
151*da0073e9SAndroid Build Coastguard Workerfrom user code:
152*da0073e9SAndroid Build Coastguard Worker   File "test_logging.py", line N, in dynamo_error_fn
153*da0073e9SAndroid Build Coastguard Worker    output = output.add(torch.ones(10, 10))""",  # noqa: B950
154*da0073e9SAndroid Build Coastguard Worker        )
155*da0073e9SAndroid Build Coastguard Worker
    # Generated tests: bound the number of records captured for the aot and
    # inductor loggers at the given levels.
    test_aot = within_range_record_test(2, 6, aot=logging.INFO)
    test_inductor_debug = within_range_record_test(3, 17, inductor=logging.DEBUG)
    test_inductor_info = within_range_record_test(2, 4, inductor=logging.INFO)
159*da0073e9SAndroid Build Coastguard Worker
160*da0073e9SAndroid Build Coastguard Worker    @make_logging_test()
161*da0073e9SAndroid Build Coastguard Worker    def test_inductor_error(self, records):
162*da0073e9SAndroid Build Coastguard Worker        exitstack = contextlib.ExitStack()
163*da0073e9SAndroid Build Coastguard Worker        import torch._inductor.lowering
164*da0073e9SAndroid Build Coastguard Worker
165*da0073e9SAndroid Build Coastguard Worker        def throw(x):
166*da0073e9SAndroid Build Coastguard Worker            raise AssertionError
167*da0073e9SAndroid Build Coastguard Worker
168*da0073e9SAndroid Build Coastguard Worker        # inject an error in the lowerings
169*da0073e9SAndroid Build Coastguard Worker        dict_entries = {}
170*da0073e9SAndroid Build Coastguard Worker        for x in list(torch._inductor.lowering.lowerings.keys()):
171*da0073e9SAndroid Build Coastguard Worker            if "round" in x.__name__:
172*da0073e9SAndroid Build Coastguard Worker                dict_entries[x] = throw
173*da0073e9SAndroid Build Coastguard Worker
174*da0073e9SAndroid Build Coastguard Worker        exitstack.enter_context(
175*da0073e9SAndroid Build Coastguard Worker            unittest.mock.patch.dict(torch._inductor.lowering.lowerings, dict_entries)
176*da0073e9SAndroid Build Coastguard Worker        )
177*da0073e9SAndroid Build Coastguard Worker
178*da0073e9SAndroid Build Coastguard Worker        try:
179*da0073e9SAndroid Build Coastguard Worker            fn_opt = torch._dynamo.optimize("inductor")(inductor_error_fn)
180*da0073e9SAndroid Build Coastguard Worker            fn_opt(*ARGS)
181*da0073e9SAndroid Build Coastguard Worker        except Exception:
182*da0073e9SAndroid Build Coastguard Worker            pass
183*da0073e9SAndroid Build Coastguard Worker        record = self.getRecord(records, "WON'T CONVERT")
184*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
185*da0073e9SAndroid Build Coastguard Worker            munge_exc(record.getMessage()),
186*da0073e9SAndroid Build Coastguard Worker            """\
187*da0073e9SAndroid Build Coastguard WorkerWON'T CONVERT inductor_error_fn test_logging.py line N
188*da0073e9SAndroid Build Coastguard Workerdue to:
189*da0073e9SAndroid Build Coastguard WorkerTraceback (most recent call last):
190*da0073e9SAndroid Build Coastguard Worker  File "test_logging.py", line N, in throw
191*da0073e9SAndroid Build Coastguard Worker    raise AssertionError
192*da0073e9SAndroid Build Coastguard Workertorch._inductor.exc.LoweringException: AssertionError:
193*da0073e9SAndroid Build Coastguard Worker  target: aten.round.default
194*da0073e9SAndroid Build Coastguard Worker  args[0]: TensorBox(StorageBox(
195*da0073e9SAndroid Build Coastguard Worker    InputBuffer(name='primals_1', layout=FixedLayout('cpu', torch.float32, size=[1000, 1000], stride=[1000, 1]))
196*da0073e9SAndroid Build Coastguard Worker  ))
197*da0073e9SAndroid Build Coastguard Worker
198*da0073e9SAndroid Build Coastguard WorkerThe above exception was the direct cause of the following exception:
199*da0073e9SAndroid Build Coastguard Worker
200*da0073e9SAndroid Build Coastguard WorkerTraceback (most recent call last):
201*da0073e9SAndroid Build Coastguard Workertorch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
202*da0073e9SAndroid Build Coastguard WorkerLoweringException: AssertionError:
203*da0073e9SAndroid Build Coastguard Worker  target: aten.round.default
204*da0073e9SAndroid Build Coastguard Worker  args[0]: TensorBox(StorageBox(
205*da0073e9SAndroid Build Coastguard Worker    InputBuffer(name='primals_1', layout=FixedLayout('cpu', torch.float32, size=[1000, 1000], stride=[1000, 1]))
206*da0073e9SAndroid Build Coastguard Worker  ))""",
207*da0073e9SAndroid Build Coastguard Worker        )
208*da0073e9SAndroid Build Coastguard Worker
209*da0073e9SAndroid Build Coastguard Worker        exitstack.close()
210*da0073e9SAndroid Build Coastguard Worker
211*da0073e9SAndroid Build Coastguard Worker    @requires_distributed()
212*da0073e9SAndroid Build Coastguard Worker    @requires_cuda
213*da0073e9SAndroid Build Coastguard Worker    @make_logging_test(ddp_graphs=True)
214*da0073e9SAndroid Build Coastguard Worker    def test_ddp_graphs(self, records):
215*da0073e9SAndroid Build Coastguard Worker        class ToyModel(torch.nn.Module):
216*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
217*da0073e9SAndroid Build Coastguard Worker                super().__init__()
218*da0073e9SAndroid Build Coastguard Worker                self.layers = torch.nn.Sequential(
219*da0073e9SAndroid Build Coastguard Worker                    torch.nn.Linear(1024, 1024),
220*da0073e9SAndroid Build Coastguard Worker                    torch.nn.Linear(1024, 1024),
221*da0073e9SAndroid Build Coastguard Worker                )
222*da0073e9SAndroid Build Coastguard Worker
223*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
224*da0073e9SAndroid Build Coastguard Worker                return self.layers(x)
225*da0073e9SAndroid Build Coastguard Worker
226*da0073e9SAndroid Build Coastguard Worker        os.environ["MASTER_ADDR"] = "localhost"
227*da0073e9SAndroid Build Coastguard Worker        os.environ["MASTER_PORT"] = str(find_free_port())
228*da0073e9SAndroid Build Coastguard Worker        dist.init_process_group("gloo", rank=0, world_size=1)
229*da0073e9SAndroid Build Coastguard Worker
230*da0073e9SAndroid Build Coastguard Worker        ddp_model = torch._dynamo.optimize("inductor")(
231*da0073e9SAndroid Build Coastguard Worker            DDP(ToyModel().to("cuda:0"), device_ids=[0], bucket_cap_mb=4)
232*da0073e9SAndroid Build Coastguard Worker        )
233*da0073e9SAndroid Build Coastguard Worker
234*da0073e9SAndroid Build Coastguard Worker        ddp_model(torch.randn(1024, 1024, device="cuda:0"))
235*da0073e9SAndroid Build Coastguard Worker
236*da0073e9SAndroid Build Coastguard Worker        dist.destroy_process_group()
237*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len([r for r in records if "__ddp_graphs" in r.name]), 4)
238*da0073e9SAndroid Build Coastguard Worker
239*da0073e9SAndroid Build Coastguard Worker    # check that logging to a child log of a registered logger
240*da0073e9SAndroid Build Coastguard Worker    # does not register it and result in duplicated records
241*da0073e9SAndroid Build Coastguard Worker    @make_settings_test("torch._dynamo.output_graph")
242*da0073e9SAndroid Build Coastguard Worker    def test_open_registration_with_registered_parent(self, records):
243*da0073e9SAndroid Build Coastguard Worker        logger = logging.getLogger("torch._dynamo.output_graph")
244*da0073e9SAndroid Build Coastguard Worker        logger.info("hi")
245*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(records), 1)
246*da0073e9SAndroid Build Coastguard Worker
247*da0073e9SAndroid Build Coastguard Worker    # check logging to a random log that is not a child log of a registered
248*da0073e9SAndroid Build Coastguard Worker    # logger registers it and sets handlers properly
249*da0073e9SAndroid Build Coastguard Worker    @make_settings_test("torch.utils")
250*da0073e9SAndroid Build Coastguard Worker    def test_open_registration(self, records):
251*da0073e9SAndroid Build Coastguard Worker        logger = logging.getLogger("torch.utils")
252*da0073e9SAndroid Build Coastguard Worker        logger.info("hi")
253*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(records), 1)
254*da0073e9SAndroid Build Coastguard Worker
255*da0073e9SAndroid Build Coastguard Worker    # check logging to a random log that is not a child log of a registered
256*da0073e9SAndroid Build Coastguard Worker    # logger registers it and sets handlers properly
257*da0073e9SAndroid Build Coastguard Worker    @make_logging_test(modules={"torch.utils": logging.INFO})
258*da0073e9SAndroid Build Coastguard Worker    def test_open_registration_python_api(self, records):
259*da0073e9SAndroid Build Coastguard Worker        logger = logging.getLogger("torch.utils")
260*da0073e9SAndroid Build Coastguard Worker        logger.info("hi")
261*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(records), 1)
262*da0073e9SAndroid Build Coastguard Worker
263*da0073e9SAndroid Build Coastguard Worker    @make_logging_test(all=logging.DEBUG, dynamo=logging.INFO)
264*da0073e9SAndroid Build Coastguard Worker    def test_all(self, _):
265*da0073e9SAndroid Build Coastguard Worker        registry = torch._logging._internal.log_registry
266*da0073e9SAndroid Build Coastguard Worker
267*da0073e9SAndroid Build Coastguard Worker        dynamo_qnames = registry.log_alias_to_log_qnames["dynamo"]
268*da0073e9SAndroid Build Coastguard Worker        for logger_qname in torch._logging._internal.log_registry.get_log_qnames():
269*da0073e9SAndroid Build Coastguard Worker            logger = logging.getLogger(logger_qname)
270*da0073e9SAndroid Build Coastguard Worker
271*da0073e9SAndroid Build Coastguard Worker            # if logger_qname is a.b.c and dynamo_qnames contains a.b, it still matches dynamo's INFO setting
272*da0073e9SAndroid Build Coastguard Worker            if any(logger_qname.find(d) == 0 for d in dynamo_qnames):
273*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
274*da0073e9SAndroid Build Coastguard Worker                    logger.getEffectiveLevel(),
275*da0073e9SAndroid Build Coastguard Worker                    logging.INFO,
276*da0073e9SAndroid Build Coastguard Worker                    msg=f"expected {logger_qname} is INFO, got {logging.getLevelName(logger.getEffectiveLevel())}",
277*da0073e9SAndroid Build Coastguard Worker                )
278*da0073e9SAndroid Build Coastguard Worker            else:
279*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
280*da0073e9SAndroid Build Coastguard Worker                    logger.getEffectiveLevel(),
281*da0073e9SAndroid Build Coastguard Worker                    logging.DEBUG,
282*da0073e9SAndroid Build Coastguard Worker                    msg=f"expected {logger_qname} is DEBUG, got {logging.getLevelName(logger.getEffectiveLevel())}",
283*da0073e9SAndroid Build Coastguard Worker                )
284*da0073e9SAndroid Build Coastguard Worker
285*da0073e9SAndroid Build Coastguard Worker    @make_logging_test(graph_breaks=True)
286*da0073e9SAndroid Build Coastguard Worker    def test_graph_breaks(self, records):
287*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.optimize("inductor")
288*da0073e9SAndroid Build Coastguard Worker        def fn(x):
289*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.graph_break()
290*da0073e9SAndroid Build Coastguard Worker            return x + 1
291*da0073e9SAndroid Build Coastguard Worker
292*da0073e9SAndroid Build Coastguard Worker        fn(torch.ones(1))
293*da0073e9SAndroid Build Coastguard Worker
294*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(records), 1)
295*da0073e9SAndroid Build Coastguard Worker
296*da0073e9SAndroid Build Coastguard Worker    @make_settings_test("torch._dynamo.utils")
297*da0073e9SAndroid Build Coastguard Worker    def test_dump_compile_times(self, records):
298*da0073e9SAndroid Build Coastguard Worker        fn_opt = torch._dynamo.optimize("inductor")(example_fn)
299*da0073e9SAndroid Build Coastguard Worker        fn_opt(torch.ones(1000, 1000))
300*da0073e9SAndroid Build Coastguard Worker        # This function runs during exit via atexit.register.
301*da0073e9SAndroid Build Coastguard Worker        # We're not actually going to run atexit._run_exit_funcs() here,
302*da0073e9SAndroid Build Coastguard Worker        # because it'll destroy state necessary for other tests.
303*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.utils.dump_compile_times()
304*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
305*da0073e9SAndroid Build Coastguard Worker            len(
306*da0073e9SAndroid Build Coastguard Worker                [r for r in records if "TorchDynamo compilation metrics" in str(r.msg)]
307*da0073e9SAndroid Build Coastguard Worker            ),
308*da0073e9SAndroid Build Coastguard Worker            1,
309*da0073e9SAndroid Build Coastguard Worker        )
310*da0073e9SAndroid Build Coastguard Worker
311*da0073e9SAndroid Build Coastguard Worker    @make_logging_test(dynamo=logging.INFO)
312*da0073e9SAndroid Build Coastguard Worker    def test_custom_format_exc(self, records):
313*da0073e9SAndroid Build Coastguard Worker        dynamo_log = logging.getLogger(torch._dynamo.__name__)
314*da0073e9SAndroid Build Coastguard Worker        try:
315*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError("foo")
316*da0073e9SAndroid Build Coastguard Worker        except RuntimeError:
317*da0073e9SAndroid Build Coastguard Worker            dynamo_log.exception("test dynamo")
318*da0073e9SAndroid Build Coastguard Worker            dynamo_log.info("with exc", exc_info=True)
319*da0073e9SAndroid Build Coastguard Worker        dynamo_log.info("with stack", stack_info=True)
320*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(records), 3)
321*da0073e9SAndroid Build Coastguard Worker        # unfortunately there's no easy way to test the final formatted log other than
322*da0073e9SAndroid Build Coastguard Worker        # to ask the dynamo logger's handler to format it.
323*da0073e9SAndroid Build Coastguard Worker        for handler in dynamo_log.handlers:
324*da0073e9SAndroid Build Coastguard Worker            if torch._logging._internal._is_torch_handler(handler):
325*da0073e9SAndroid Build Coastguard Worker                break
326*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(handler)
327*da0073e9SAndroid Build Coastguard Worker        self.assertIn("Traceback", handler.format(records[0]))
328*da0073e9SAndroid Build Coastguard Worker        self.assertIn("Traceback", handler.format(records[1]))
329*da0073e9SAndroid Build Coastguard Worker        self.assertIn("Stack", handler.format(records[2]))
330*da0073e9SAndroid Build Coastguard Worker
331*da0073e9SAndroid Build Coastguard Worker    @make_logging_test(dynamo=logging.INFO)
332*da0073e9SAndroid Build Coastguard Worker    def test_custom_format(self, records):
333*da0073e9SAndroid Build Coastguard Worker        dynamo_log = logging.getLogger(torch._dynamo.__name__)
334*da0073e9SAndroid Build Coastguard Worker        test_log = torch._logging.getArtifactLogger(
335*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.__name__, "custom_format_test_artifact"
336*da0073e9SAndroid Build Coastguard Worker        )
337*da0073e9SAndroid Build Coastguard Worker        dynamo_log.info("test dynamo")
338*da0073e9SAndroid Build Coastguard Worker        test_log.info("custom format")
339*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(records), 2)
340*da0073e9SAndroid Build Coastguard Worker        # unfortunately there's no easy way to test the final formatted log other than
341*da0073e9SAndroid Build Coastguard Worker        # to ask the dynamo logger's handler to format it.
342*da0073e9SAndroid Build Coastguard Worker        for handler in dynamo_log.handlers:
343*da0073e9SAndroid Build Coastguard Worker            if torch._logging._internal._is_torch_handler(handler):
344*da0073e9SAndroid Build Coastguard Worker                break
345*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(handler)
346*da0073e9SAndroid Build Coastguard Worker        self.assertIn("I", handler.format(records[0]))
347*da0073e9SAndroid Build Coastguard Worker        self.assertEqual("custom format", handler.format(records[1]))
348*da0073e9SAndroid Build Coastguard Worker
349*da0073e9SAndroid Build Coastguard Worker    @make_logging_test(dynamo=logging.INFO)
350*da0073e9SAndroid Build Coastguard Worker    def test_multiline_format(self, records):
351*da0073e9SAndroid Build Coastguard Worker        dynamo_log = logging.getLogger(torch._dynamo.__name__)
352*da0073e9SAndroid Build Coastguard Worker        dynamo_log.info("test\ndynamo")
353*da0073e9SAndroid Build Coastguard Worker        dynamo_log.info("%s", "test\ndynamo")
354*da0073e9SAndroid Build Coastguard Worker        dynamo_log.info("test\n%s", "test\ndynamo")
355*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(records), 3)
356*da0073e9SAndroid Build Coastguard Worker        # unfortunately there's no easy way to test the final formatted log other than
357*da0073e9SAndroid Build Coastguard Worker        # to ask the dynamo logger's handler to format it.
358*da0073e9SAndroid Build Coastguard Worker        for handler in dynamo_log.handlers:
359*da0073e9SAndroid Build Coastguard Worker            if torch._logging._internal._is_torch_handler(handler):
360*da0073e9SAndroid Build Coastguard Worker                break
361*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(handler)
362*da0073e9SAndroid Build Coastguard Worker        for record in records:
363*da0073e9SAndroid Build Coastguard Worker            r = handler.format(record)
364*da0073e9SAndroid Build Coastguard Worker            for l in r.splitlines():
365*da0073e9SAndroid Build Coastguard Worker                self.assertIn("I", l)
366*da0073e9SAndroid Build Coastguard Worker
    # Generated test: with trace_source enabled, compiling example_fn must
    # capture between 1 and 100 records.
    test_trace_source_simple = within_range_record_test(1, 100, trace_source=True)
368*da0073e9SAndroid Build Coastguard Worker
369*da0073e9SAndroid Build Coastguard Worker    @make_logging_test(trace_source=True)
370*da0073e9SAndroid Build Coastguard Worker    def test_trace_source_if_stmt(self, records):
371*da0073e9SAndroid Build Coastguard Worker        def fn(x):
372*da0073e9SAndroid Build Coastguard Worker            if x.sum() > 0:
373*da0073e9SAndroid Build Coastguard Worker                return x * 2
374*da0073e9SAndroid Build Coastguard Worker            return x * 3
375*da0073e9SAndroid Build Coastguard Worker
376*da0073e9SAndroid Build Coastguard Worker        fn_opt = torch._dynamo.optimize("eager")(fn)
377*da0073e9SAndroid Build Coastguard Worker        fn_opt(torch.ones(3, 3))
378*da0073e9SAndroid Build Coastguard Worker
379*da0073e9SAndroid Build Coastguard Worker        found_x2 = False
380*da0073e9SAndroid Build Coastguard Worker        found_x3 = False
381*da0073e9SAndroid Build Coastguard Worker        for record in records:
382*da0073e9SAndroid Build Coastguard Worker            msg = record.getMessage()
383*da0073e9SAndroid Build Coastguard Worker            if "return x * 2" in msg:
384*da0073e9SAndroid Build Coastguard Worker                found_x2 = True
385*da0073e9SAndroid Build Coastguard Worker            if "return x * 3" in msg:
386*da0073e9SAndroid Build Coastguard Worker                found_x3 = True
387*da0073e9SAndroid Build Coastguard Worker
388*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(found_x2)
389*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(found_x3)
390*da0073e9SAndroid Build Coastguard Worker
391*da0073e9SAndroid Build Coastguard Worker    @make_logging_test(trace_source=True)
392*da0073e9SAndroid Build Coastguard Worker    def test_trace_source_nested(self, records):
393*da0073e9SAndroid Build Coastguard Worker        def fn1(x):
394*da0073e9SAndroid Build Coastguard Worker            x = fn2(x)
395*da0073e9SAndroid Build Coastguard Worker            return x * 2
396*da0073e9SAndroid Build Coastguard Worker
397*da0073e9SAndroid Build Coastguard Worker        def fn2(x):
398*da0073e9SAndroid Build Coastguard Worker            x = fn3(x)
399*da0073e9SAndroid Build Coastguard Worker            return x * 3
400*da0073e9SAndroid Build Coastguard Worker
401*da0073e9SAndroid Build Coastguard Worker        def fn3(x):
402*da0073e9SAndroid Build Coastguard Worker            return x * 4
403*da0073e9SAndroid Build Coastguard Worker
404*da0073e9SAndroid Build Coastguard Worker        fn_opt = torch._dynamo.optimize("eager")(fn1)
405*da0073e9SAndroid Build Coastguard Worker        fn_opt(torch.ones(3, 3))
406*da0073e9SAndroid Build Coastguard Worker
407*da0073e9SAndroid Build Coastguard Worker        found_x2 = False
408*da0073e9SAndroid Build Coastguard Worker        found_x3 = False
409*da0073e9SAndroid Build Coastguard Worker        found_x4 = False
410*da0073e9SAndroid Build Coastguard Worker        for record in records:
411*da0073e9SAndroid Build Coastguard Worker            msg = record.getMessage()
412*da0073e9SAndroid Build Coastguard Worker            if "return x * 2" in msg:
413*da0073e9SAndroid Build Coastguard Worker                found_x2 = True
414*da0073e9SAndroid Build Coastguard Worker                self.assertNotIn("inline depth", msg)
415*da0073e9SAndroid Build Coastguard Worker            elif "return x * 3" in msg:
416*da0073e9SAndroid Build Coastguard Worker                found_x3 = True
417*da0073e9SAndroid Build Coastguard Worker                self.assertIn("inline depth: 1", msg)
418*da0073e9SAndroid Build Coastguard Worker            elif "return x * 4" in msg:
419*da0073e9SAndroid Build Coastguard Worker                found_x4 = True
420*da0073e9SAndroid Build Coastguard Worker                self.assertIn("inline depth: 2", msg)
421*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(found_x2)
422*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(found_x3)
423*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(found_x4)
424*da0073e9SAndroid Build Coastguard Worker
425*da0073e9SAndroid Build Coastguard Worker    @make_logging_test(trace_source=True)
426*da0073e9SAndroid Build Coastguard Worker    def test_trace_source_cond(self, records):
427*da0073e9SAndroid Build Coastguard Worker        from functorch.experimental.control_flow import cond
428*da0073e9SAndroid Build Coastguard Worker
429*da0073e9SAndroid Build Coastguard Worker        def true_fn(x):
430*da0073e9SAndroid Build Coastguard Worker            return x * 2
431*da0073e9SAndroid Build Coastguard Worker
432*da0073e9SAndroid Build Coastguard Worker        def false_fn(x):
433*da0073e9SAndroid Build Coastguard Worker            return x * 3
434*da0073e9SAndroid Build Coastguard Worker
435*da0073e9SAndroid Build Coastguard Worker        def inner(pred, x):
436*da0073e9SAndroid Build Coastguard Worker            return cond(pred, true_fn, false_fn, [x])
437*da0073e9SAndroid Build Coastguard Worker
438*da0073e9SAndroid Build Coastguard Worker        def outer(pred, x):
439*da0073e9SAndroid Build Coastguard Worker            return inner(pred, x)
440*da0073e9SAndroid Build Coastguard Worker
441*da0073e9SAndroid Build Coastguard Worker        fn_opt = torch._dynamo.optimize("eager")(outer)
442*da0073e9SAndroid Build Coastguard Worker        fn_opt(torch.tensor(True), torch.ones(3, 3))
443*da0073e9SAndroid Build Coastguard Worker
444*da0073e9SAndroid Build Coastguard Worker        found_x2 = False
445*da0073e9SAndroid Build Coastguard Worker        found_x3 = False
446*da0073e9SAndroid Build Coastguard Worker        for record in records:
447*da0073e9SAndroid Build Coastguard Worker            msg = record.getMessage()
448*da0073e9SAndroid Build Coastguard Worker            if "return x * 2" in msg:
449*da0073e9SAndroid Build Coastguard Worker                found_x2 = True
450*da0073e9SAndroid Build Coastguard Worker                self.assertIn("inline depth: 3", msg)
451*da0073e9SAndroid Build Coastguard Worker            if "return x * 3" in msg:
452*da0073e9SAndroid Build Coastguard Worker                found_x3 = True
453*da0073e9SAndroid Build Coastguard Worker                self.assertIn("inline depth: 3", msg)
454*da0073e9SAndroid Build Coastguard Worker
455*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(found_x2)
456*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(found_x3)
457*da0073e9SAndroid Build Coastguard Worker
458*da0073e9SAndroid Build Coastguard Worker    @make_logging_test(trace_source=True)
459*da0073e9SAndroid Build Coastguard Worker    def test_trace_source_funcname(self, records):
460*da0073e9SAndroid Build Coastguard Worker        # NOTE: list comprehensions are inlined in 3.12, so test with tuples
461*da0073e9SAndroid Build Coastguard Worker        def fn1():
462*da0073e9SAndroid Build Coastguard Worker            def fn2():
463*da0073e9SAndroid Build Coastguard Worker                if True:
464*da0073e9SAndroid Build Coastguard Worker                    return tuple(torch.ones(3, 3) for _ in range(5))
465*da0073e9SAndroid Build Coastguard Worker                return None
466*da0073e9SAndroid Build Coastguard Worker
467*da0073e9SAndroid Build Coastguard Worker            return fn2()
468*da0073e9SAndroid Build Coastguard Worker
469*da0073e9SAndroid Build Coastguard Worker        fn_opt = torch._dynamo.optimize("eager")(fn1)
470*da0073e9SAndroid Build Coastguard Worker        fn_opt()
471*da0073e9SAndroid Build Coastguard Worker
472*da0073e9SAndroid Build Coastguard Worker        found_funcname = False
473*da0073e9SAndroid Build Coastguard Worker        for record in records:
474*da0073e9SAndroid Build Coastguard Worker            msg = record.getMessage()
475*da0073e9SAndroid Build Coastguard Worker            if "<genexpr>" in msg and "fn1.fn2" in msg:
476*da0073e9SAndroid Build Coastguard Worker                found_funcname = True
477*da0073e9SAndroid Build Coastguard Worker
478*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(found_funcname)
479*da0073e9SAndroid Build Coastguard Worker
480*da0073e9SAndroid Build Coastguard Worker    def test_invalid_artifact_flag(self):
481*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
482*da0073e9SAndroid Build Coastguard Worker            torch._logging.set_logs(aot_graphs=5)
483*da0073e9SAndroid Build Coastguard Worker
484*da0073e9SAndroid Build Coastguard Worker    @requires_distributed()
485*da0073e9SAndroid Build Coastguard Worker    def test_distributed_rank_logging(self):
486*da0073e9SAndroid Build Coastguard Worker        env = dict(os.environ)
487*da0073e9SAndroid Build Coastguard Worker        env["TORCH_LOGS"] = "dynamo"
488*da0073e9SAndroid Build Coastguard Worker        stdout, stderr = self.run_process_no_exception(
489*da0073e9SAndroid Build Coastguard Worker            """\
490*da0073e9SAndroid Build Coastguard Workerimport torch.distributed as dist
491*da0073e9SAndroid Build Coastguard Workerimport logging
492*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.distributed.fake_pg import FakeStore
493*da0073e9SAndroid Build Coastguard Workerstore = FakeStore()
494*da0073e9SAndroid Build Coastguard Workerdist.init_process_group("fake", rank=0, world_size=2, store=store)
495*da0073e9SAndroid Build Coastguard Workerdynamo_log = logging.getLogger("torch._dynamo")
496*da0073e9SAndroid Build Coastguard Workerdynamo_log.info("woof")
497*da0073e9SAndroid Build Coastguard Workerprint("arf")
498*da0073e9SAndroid Build Coastguard Worker""",
499*da0073e9SAndroid Build Coastguard Worker            env=env,
500*da0073e9SAndroid Build Coastguard Worker        )
501*da0073e9SAndroid Build Coastguard Worker        self.assertIn("[rank0]:", stderr.decode("utf-8"))
502*da0073e9SAndroid Build Coastguard Worker
503*da0073e9SAndroid Build Coastguard Worker    @skipIfNotPy311
504*da0073e9SAndroid Build Coastguard Worker    @make_logging_test(trace_call=True)
505*da0073e9SAndroid Build Coastguard Worker    def test_trace_call(self, records):
506*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
507*da0073e9SAndroid Build Coastguard Worker            return (x * 2) @ (y * 3)
508*da0073e9SAndroid Build Coastguard Worker
509*da0073e9SAndroid Build Coastguard Worker        fn_opt = torch._dynamo.optimize("eager")(fn)
510*da0073e9SAndroid Build Coastguard Worker        fn_opt(torch.randn(10, 20), torch.randn(20, 30))
511*da0073e9SAndroid Build Coastguard Worker
512*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(records), 3)
513*da0073e9SAndroid Build Coastguard Worker        # only get last 2 lines
514*da0073e9SAndroid Build Coastguard Worker        messages = [
515*da0073e9SAndroid Build Coastguard Worker            "\n".join(record.getMessage().split("\n")[-2:]) for record in records
516*da0073e9SAndroid Build Coastguard Worker        ]
517*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
518*da0073e9SAndroid Build Coastguard Worker            messages[0],
519*da0073e9SAndroid Build Coastguard Worker            """\
520*da0073e9SAndroid Build Coastguard Worker            return (x * 2) @ (y * 3)
521*da0073e9SAndroid Build Coastguard Worker                    ~~^~~""",
522*da0073e9SAndroid Build Coastguard Worker        )
523*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
524*da0073e9SAndroid Build Coastguard Worker            messages[1],
525*da0073e9SAndroid Build Coastguard Worker            """\
526*da0073e9SAndroid Build Coastguard Worker            return (x * 2) @ (y * 3)
527*da0073e9SAndroid Build Coastguard Worker                              ~~^~~""",
528*da0073e9SAndroid Build Coastguard Worker        )
529*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
530*da0073e9SAndroid Build Coastguard Worker            messages[2],
531*da0073e9SAndroid Build Coastguard Worker            """\
532*da0073e9SAndroid Build Coastguard Worker            return (x * 2) @ (y * 3)
533*da0073e9SAndroid Build Coastguard Worker                   ~~~~~~~~^~~~~~~~~""",
534*da0073e9SAndroid Build Coastguard Worker        )
535*da0073e9SAndroid Build Coastguard Worker
536*da0073e9SAndroid Build Coastguard Worker    @skipIfNotPy311
537*da0073e9SAndroid Build Coastguard Worker    @make_logging_test(trace_call=True)
538*da0073e9SAndroid Build Coastguard Worker    def test_trace_call_inline_call(self, records):
539*da0073e9SAndroid Build Coastguard Worker        def g(x):
540*da0073e9SAndroid Build Coastguard Worker            return x * 2
541*da0073e9SAndroid Build Coastguard Worker
542*da0073e9SAndroid Build Coastguard Worker        def f(x):
543*da0073e9SAndroid Build Coastguard Worker            return g(g(x))
544*da0073e9SAndroid Build Coastguard Worker
545*da0073e9SAndroid Build Coastguard Worker        fn_opt = torch._dynamo.optimize("eager")(f)
546*da0073e9SAndroid Build Coastguard Worker        fn_opt(torch.randn(3, 3))
547*da0073e9SAndroid Build Coastguard Worker
548*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(records), 4)
549*da0073e9SAndroid Build Coastguard Worker        messages = [
550*da0073e9SAndroid Build Coastguard Worker            "\n".join(record.getMessage().split("\n")[-2:]) for record in records
551*da0073e9SAndroid Build Coastguard Worker        ]
552*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
553*da0073e9SAndroid Build Coastguard Worker            messages[0],
554*da0073e9SAndroid Build Coastguard Worker            """\
555*da0073e9SAndroid Build Coastguard Worker            return g(g(x))
556*da0073e9SAndroid Build Coastguard Worker                     ~^^^""",
557*da0073e9SAndroid Build Coastguard Worker        )
558*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
559*da0073e9SAndroid Build Coastguard Worker            messages[1],
560*da0073e9SAndroid Build Coastguard Worker            """\
561*da0073e9SAndroid Build Coastguard Worker            return x * 2
562*da0073e9SAndroid Build Coastguard Worker                   ~~^~~""",
563*da0073e9SAndroid Build Coastguard Worker        )
564*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
565*da0073e9SAndroid Build Coastguard Worker            messages[2],
566*da0073e9SAndroid Build Coastguard Worker            """\
567*da0073e9SAndroid Build Coastguard Worker            return g(g(x))
568*da0073e9SAndroid Build Coastguard Worker                   ~^^^^^^""",
569*da0073e9SAndroid Build Coastguard Worker        )
570*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
571*da0073e9SAndroid Build Coastguard Worker            messages[3],
572*da0073e9SAndroid Build Coastguard Worker            """\
573*da0073e9SAndroid Build Coastguard Worker            return x * 2
574*da0073e9SAndroid Build Coastguard Worker                   ~~^~~""",
575*da0073e9SAndroid Build Coastguard Worker        )
576*da0073e9SAndroid Build Coastguard Worker
577*da0073e9SAndroid Build Coastguard Worker    @skipIfNotPy311
578*da0073e9SAndroid Build Coastguard Worker    @make_logging_test(trace_call=True)
579*da0073e9SAndroid Build Coastguard Worker    def test_trace_call_graph_break(self, records):
580*da0073e9SAndroid Build Coastguard Worker        def fn(x):
581*da0073e9SAndroid Build Coastguard Worker            x = x * 2
582*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.graph_break()
583*da0073e9SAndroid Build Coastguard Worker            return x * 3
584*da0073e9SAndroid Build Coastguard Worker
585*da0073e9SAndroid Build Coastguard Worker        fn_opt = torch._dynamo.optimize("eager")(fn)
586*da0073e9SAndroid Build Coastguard Worker        fn_opt(torch.randn(3, 3))
587*da0073e9SAndroid Build Coastguard Worker
588*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(records), 3)
589*da0073e9SAndroid Build Coastguard Worker        messages = [
590*da0073e9SAndroid Build Coastguard Worker            "\n".join(record.getMessage().split("\n")[-2:]) for record in records
591*da0073e9SAndroid Build Coastguard Worker        ]
592*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
593*da0073e9SAndroid Build Coastguard Worker            messages[0],
594*da0073e9SAndroid Build Coastguard Worker            """\
595*da0073e9SAndroid Build Coastguard Worker            x = x * 2
596*da0073e9SAndroid Build Coastguard Worker                ~~^~~""",
597*da0073e9SAndroid Build Coastguard Worker        )
598*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
599*da0073e9SAndroid Build Coastguard Worker            messages[-1],
600*da0073e9SAndroid Build Coastguard Worker            """\
601*da0073e9SAndroid Build Coastguard Worker            return x * 3
602*da0073e9SAndroid Build Coastguard Worker                   ~~^~~""",
603*da0073e9SAndroid Build Coastguard Worker        )
604*da0073e9SAndroid Build Coastguard Worker
605*da0073e9SAndroid Build Coastguard Worker    @make_logging_test(guards=True, recompiles=True)
606*da0073e9SAndroid Build Coastguard Worker    def test_guards_recompiles(self, records):
607*da0073e9SAndroid Build Coastguard Worker        def fn(x, ys, zs):
608*da0073e9SAndroid Build Coastguard Worker            return inner(x, ys, zs)
609*da0073e9SAndroid Build Coastguard Worker
610*da0073e9SAndroid Build Coastguard Worker        def inner(x, ys, zs):
611*da0073e9SAndroid Build Coastguard Worker            for y, z in zip(ys, zs):
612*da0073e9SAndroid Build Coastguard Worker                x += y * z
613*da0073e9SAndroid Build Coastguard Worker            return x
614*da0073e9SAndroid Build Coastguard Worker
615*da0073e9SAndroid Build Coastguard Worker        ys = [1.0, 2.0]
616*da0073e9SAndroid Build Coastguard Worker        zs = [3.0]
617*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([1.0])
618*da0073e9SAndroid Build Coastguard Worker
619*da0073e9SAndroid Build Coastguard Worker        fn_opt = torch._dynamo.optimize("eager")(fn)
620*da0073e9SAndroid Build Coastguard Worker        fn_opt(x, ys, zs)
621*da0073e9SAndroid Build Coastguard Worker        fn_opt(x, ys[:1], zs)
622*da0073e9SAndroid Build Coastguard Worker
623*da0073e9SAndroid Build Coastguard Worker        record_str = "\n".join(r.getMessage() for r in records)
624*da0073e9SAndroid Build Coastguard Worker
625*da0073e9SAndroid Build Coastguard Worker        self.assertIn(
626*da0073e9SAndroid Build Coastguard Worker            """L['zs'][0] == 3.0""",
627*da0073e9SAndroid Build Coastguard Worker            record_str,
628*da0073e9SAndroid Build Coastguard Worker        )
629*da0073e9SAndroid Build Coastguard Worker        self.assertIn(
630*da0073e9SAndroid Build Coastguard Worker            "len(L['ys']) == 2",
631*da0073e9SAndroid Build Coastguard Worker            record_str,
632*da0073e9SAndroid Build Coastguard Worker        )
633*da0073e9SAndroid Build Coastguard Worker
634*da0073e9SAndroid Build Coastguard Worker    @make_logging_test(cudagraph_static_inputs=True)
635*da0073e9SAndroid Build Coastguard Worker    def test_cudagraph_static_inputs(self, records):
636*da0073e9SAndroid Build Coastguard Worker        @torch.compile(mode="reduce-overhead")
637*da0073e9SAndroid Build Coastguard Worker        def fn(x):
638*da0073e9SAndroid Build Coastguard Worker            return x + 1
639*da0073e9SAndroid Build Coastguard Worker
640*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(2, 2)
641*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.mark_static_address(x)
642*da0073e9SAndroid Build Coastguard Worker        fn(x)
643*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(len(records), 0)
644*da0073e9SAndroid Build Coastguard Worker        self.assertLess(len(records), 4)
645*da0073e9SAndroid Build Coastguard Worker
646*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("too slow")
647*da0073e9SAndroid Build Coastguard Worker    @make_logging_test(**torch._logging.DEFAULT_LOGGING)
648*da0073e9SAndroid Build Coastguard Worker    def test_default_logging(self, records):
649*da0073e9SAndroid Build Coastguard Worker        def fn(a):
650*da0073e9SAndroid Build Coastguard Worker            if a.sum() < 0:
651*da0073e9SAndroid Build Coastguard Worker                a = torch.sin(a)
652*da0073e9SAndroid Build Coastguard Worker            else:
653*da0073e9SAndroid Build Coastguard Worker                a = torch.cos(a)
654*da0073e9SAndroid Build Coastguard Worker            print("hello")
655*da0073e9SAndroid Build Coastguard Worker            return a + 1
656*da0073e9SAndroid Build Coastguard Worker
657*da0073e9SAndroid Build Coastguard Worker        fn_opt = torch._dynamo.optimize("eager")(fn)
658*da0073e9SAndroid Build Coastguard Worker        fn_opt(torch.ones(10, 10))
659*da0073e9SAndroid Build Coastguard Worker        fn_opt(-torch.ones(10, 5))
660*da0073e9SAndroid Build Coastguard Worker
661*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(len([r for r in records if ".__graph_breaks" in r.name]), 0)
662*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(len([r for r in records if ".__recompiles" in r.name]), 0)
663*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(len([r for r in records if ".symbolic_shapes" in r.name]), 0)
664*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(len([r for r in records if ".__guards" in r.name]), 0)
665*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(
666*da0073e9SAndroid Build Coastguard Worker            len([r for r in records if "return a + 1" in r.getMessage()]), 0
667*da0073e9SAndroid Build Coastguard Worker        )
668*da0073e9SAndroid Build Coastguard Worker
669*da0073e9SAndroid Build Coastguard Worker    def test_logs_out(self):
670*da0073e9SAndroid Build Coastguard Worker        import tempfile
671*da0073e9SAndroid Build Coastguard Worker
672*da0073e9SAndroid Build Coastguard Worker        with tempfile.NamedTemporaryFile(delete=False) as tmp:
673*da0073e9SAndroid Build Coastguard Worker            file_path = _as_posix_path(tmp.name)
674*da0073e9SAndroid Build Coastguard Worker            """
675*da0073e9SAndroid Build Coastguard Worker            NamedTemporaryFile will include a file open operation.
676*da0073e9SAndroid Build Coastguard Worker            On Windowsm the file is opened by NamedTemporaryFile, the
677*da0073e9SAndroid Build Coastguard Worker            following run_process_no_exception can't access a opened file.
678*da0073e9SAndroid Build Coastguard Worker            And then, raise a PermissionError: [Errno 13] Permission denied: [file_path]
679*da0073e9SAndroid Build Coastguard Worker            """
680*da0073e9SAndroid Build Coastguard Worker            tmp.close()
681*da0073e9SAndroid Build Coastguard Worker            env = dict(os.environ)
682*da0073e9SAndroid Build Coastguard Worker            env["TORCH_LOGS"] = "dynamo"
683*da0073e9SAndroid Build Coastguard Worker            env["TORCH_LOGS_OUT"] = file_path
684*da0073e9SAndroid Build Coastguard Worker            stdout, stderr = self.run_process_no_exception(
685*da0073e9SAndroid Build Coastguard Worker                """\
686*da0073e9SAndroid Build Coastguard Workerimport torch
687*da0073e9SAndroid Build Coastguard Worker@torch.compile(backend="eager")
688*da0073e9SAndroid Build Coastguard Workerdef fn(a):
689*da0073e9SAndroid Build Coastguard Worker    return a.sum()
690*da0073e9SAndroid Build Coastguard Worker
691*da0073e9SAndroid Build Coastguard Workerfn(torch.randn(5))
692*da0073e9SAndroid Build Coastguard Worker                """,
693*da0073e9SAndroid Build Coastguard Worker                env=env,
694*da0073e9SAndroid Build Coastguard Worker            )
695*da0073e9SAndroid Build Coastguard Worker            with open(
696*da0073e9SAndroid Build Coastguard Worker                file_path, encoding="utf-8"
697*da0073e9SAndroid Build Coastguard Worker            ) as fd:  # encoding file to UTF-8 for Windows.
698*da0073e9SAndroid Build Coastguard Worker                lines = fd.read()
699*da0073e9SAndroid Build Coastguard Worker                fd.close()
700*da0073e9SAndroid Build Coastguard Worker                os.remove(
701*da0073e9SAndroid Build Coastguard Worker                    file_path
702*da0073e9SAndroid Build Coastguard Worker                )  # Delete temp file manually, due to setup NamedTemporaryFile as delete=False.
703*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(  # process wrap difference: /r/n on Windows, /n on posix.
704*da0073e9SAndroid Build Coastguard Worker                    empty_line_normalizer(lines),
705*da0073e9SAndroid Build Coastguard Worker                    empty_line_normalizer(stderr.decode("utf-8")),
706*da0073e9SAndroid Build Coastguard Worker                )
707*da0073e9SAndroid Build Coastguard Worker
708*da0073e9SAndroid Build Coastguard Worker
# single record tests
# Artifact names listed here do not get an auto-generated test below.
exclusions = {
    "bytecode",
    "cudagraphs",
    "output_code",
    "schedule",
    "fusion",
    "overlap",
    "aot_graphs",
    "aot_graphs_effects",
    "post_grad_graphs",
    "compiled_autograd",
    "compiled_autograd_verbose",
    "recompiles",
    "recompiles_verbose",
    "graph_breaks",
    "graph",
    "graph_code",
    "graph_sizes",
    "ddp_graphs",
    "perf_hints",
    "not_implemented",
    "trace_source",
    "trace_call",
    "trace_bytecode",
    "custom_format_test_artifact",
    "onnx",
    "onnx_diagnostics",
    "guards",
    "verbose_guards",
    "sym_node",
    "export",
    "trace_shape_events",
    "cudagraph_static_inputs",
    "benchmarking",
    "loop_ordering",
}
# Attach a generated `test_<artifact>` method to LoggingTests for every
# registered artifact name that is not excluded.
for name in torch._logging._internal.log_registry.artifact_names:
    if name in exclusions:
        continue
    setattr(LoggingTests, f"test_{name}", single_record_test(**{name: True}))
749*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # torch._dynamo.test_case is already imported at the top of this file.
    torch._dynamo.test_case.run_tests()
754