# xref: /aosp_15_r20/external/pytorch/test/jit/test_with.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["oncall: jit"]

3import os
4import sys
5from typing import Any, List
6
7import torch
8from torch.testing._internal.common_utils import skipIfTorchDynamo
9from torch.testing._internal.jit_utils import JitTestCase, make_global
10
11
12# Make the helper files in test/ importable
13pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
14sys.path.append(pytorch_test_dir)
15
16if __name__ == "__main__":
17    raise RuntimeError(
18        "This test file is not meant to be run directly, use:\n\n"
19        "\tpython test/test_jit.py TESTNAME\n\n"
20        "instead."
21    )
22
23
24class TestWith(JitTestCase):
25    """
26    A suite of tests for with statements.
27    """
28
29    def test_with_as(self):
30        """
31        Check that with statements that use the 'as' keyword to bind expressions
32        to targets work as expected.
33        """
34
35        @torch.jit.script
36        class Context:
37            """
38            This class implements a basic context manager interface for use in
39            the unit tests. Unlike Context, the stateful part of this class
40            is a Tensor that is mutated in-place so that modifications made in the
41            JIT interpreter are visible outside of it.
42            """
43
44            def __init__(self, start: int):
45                self.count = torch.tensor([start], dtype=torch.double)
46
47            def __enter__(self):
48                self.count.add_(0.3)
49                return self.count
50
51            def __exit__(self, type: Any, value: Any, tb: Any) -> bool:
52                self.count.sub_(0.3)
53                return True
54
55        make_global(Context)
56
57        def test_basic(x: torch.Tensor) -> torch.Tensor:
58            """Basic test with one with-statement."""
59
60            c = Context(1)
61
62            with c as mult:
63                y = x + mult
64
65            y *= c.count
66            return y
67
68        def test_pass(x: torch.Tensor) -> torch.Tensor:
69            """
70            Test with a pass statement inside a with-statement. Although
71            the body of the with is empty, __enter__ and __exit__ should
72            still be called.
73            """
74            c = Context(1)
75
76            with c as mult:
77                pass
78
79            x *= c.count
80            return x
81
82        def test_early_return(x: torch.Tensor, c: Context) -> torch.Tensor:
83            """
84            Test that returning early from inside a with-statement works
85            as expected.
86            """
87            with c as mult:
88                y = x + mult
89                return y
90
91            x = y + y
92            return x
93
94        def test_conditional_early_return(x: torch.Tensor, c: Context) -> torch.Tensor:
95            """
96            Test that conditionally returning early from inside a with-statement works
97            as expected.
98            """
99            with c as mult:
100                y = x + mult
101                if mult > 0:
102                    return y
103
104            x = y + y
105            return x
106
107        def test_break(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor:
108            """
109            Test that breaking early from inside a with-statement works
110            as expected.
111            """
112            with c as mult:
113                for a in l:
114                    if a == 0:
115                        break
116                    x += a * mult
117
118            return x
119
120        def test_continue(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor:
121            """
122            Test that using continue inside a with-statement works
123            as expected.
124            """
125            with c as mult:
126                for a in l:
127                    if a == 0:
128                        continue
129                    x += a * mult
130
131            return x
132
133        def test_serial(x: torch.Tensor) -> torch.Tensor:
134            """
135            Test two with-statements in a row.
136            """
137            c = Context(1)
138
139            with c as mult:
140                y = x + mult
141
142            with c as mult:
143                y *= mult
144
145            return y
146
147        def test_nested(x: torch.Tensor) -> torch.Tensor:
148            """
149            Test nested with-statements.
150            """
151            c = Context(1)
152
153            with c as m:
154                with c as n:
155                    y = x + n
156
157                y *= m
158
159            return y
160
161        def test_combined(x: torch.Tensor) -> torch.Tensor:
162            """
163            Test a with-statement with multiple with items.
164            """
165            c = Context(1)
166            d = Context(2)
167
168            with c as m, d as n:
169                y = x + (m + n)
170
171            return y
172
173        test_input = torch.randn(2, 2)
174        test_context = Context(2)
175        test_list = [2, 0, 1, 3, 0, 2]
176
177        self.checkScript(test_basic, (test_input,))
178        self.checkScript(test_pass, (test_input,))
179        self.checkScript(test_early_return, (test_input, test_context))
180        self.checkScript(test_break, (test_input, test_context, test_list))
181        self.checkScript(test_continue, (test_input, test_context, test_list))
182        self.assertEqual(test_context.count, 2)
183        self.checkScript(test_serial, (test_input,))
184        self.checkScript(test_nested, (test_input,))
185        self.checkScript(test_combined, (test_input,))
186
187    def test_with_no_as(self):
188        """
189        Check that with statements that do not use the 'as' keyword to bind expressions
190        to targets work as expected.
191        """
192
193        @torch.jit.script
194        class Context:
195            """
196            This class implements a basic context manager interface for use in
197            the unit tests. Unlike Context, the stateful part of this class
198            is a Tensor that is mutated in-place so that modifications made in the
199            JIT interpreter are visible outside of it.
200            """
201
202            def __init__(self, start: int):
203                self.count = torch.tensor([start], dtype=torch.double)
204
205            def __enter__(self):
206                self.count.add_(0.3)
207                return self.count
208
209            def __exit__(self, type: Any, value: Any, tb: Any):
210                self.count.sub_(0.3)
211
212        make_global(Context)
213
214        def test_basic(x: torch.Tensor) -> torch.Tensor:
215            """Basic test with one with-statement."""
216
217            c = Context(1)
218
219            with c:
220                y = x + c.count
221
222            y *= c.count
223            return y
224
225        def test_pass(x: torch.Tensor) -> torch.Tensor:
226            """
227            Test with a pass statement inside a with-statement. Although
228            the body of the with is empty, __enter__ and __exit__ should
229            still be called.
230            """
231            c = Context(1)
232
233            with c:
234                pass
235
236            x *= c.count
237            return x
238
239        def test_early_return(x: torch.Tensor, c: Context) -> torch.Tensor:
240            """
241            Test that returning early from inside a with-statement works
242            as expected.
243            """
244            with c:
245                y = x + c.count
246                return y
247
248            x = y + y
249            return x
250
251        def test_conditional_early_return(x: torch.Tensor, c: Context) -> torch.Tensor:
252            """
253            Test that conditionally returning early from inside a with-statement works
254            as expected.
255            """
256            with c:
257                y = x + c.count
258                if c.count > 0:
259                    return y
260
261            x = y + y
262            return x
263
264        def test_break(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor:
265            """
266            Test that breaking early from inside a with-statement works
267            as expected.
268            """
269            with c:
270                for a in l:
271                    if a == 0:
272                        break
273                    x += a * c.count
274
275            return x
276
277        def test_continue(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor:
278            """
279            Test that using continue inside a with-statement works
280            as expected.
281            """
282            with c:
283                for a in l:
284                    if a == 0:
285                        continue
286                    x += a * c.count
287
288            return x
289
290        def test_serial(x: torch.Tensor) -> torch.Tensor:
291            """
292            Test two with-statements in a row.
293            """
294            c = Context(1)
295
296            with c:
297                y = x + c.count
298
299            with c:
300                y *= c.count
301
302            return y
303
304        def test_nested(x: torch.Tensor) -> torch.Tensor:
305            """
306            Test nested with-statements.
307            """
308            c = Context(1)
309
310            with c:
311                with c:
312                    y = x + c.count
313
314                y *= c.count
315
316            return y
317
318        def test_combined(x: torch.Tensor) -> torch.Tensor:
319            """
320            Test a with-statement with multiple with items.
321            """
322            c = Context(1)
323            d = Context(2)
324
325            with c, d:
326                y = x + (c.count + d.count)
327
328            return y
329
330        test_input = torch.randn(2, 2)
331        test_context = Context(2)
332        test_list = [2, 0, 1, 3, 0, 2]
333
334        self.checkScript(test_basic, (test_input,))
335        self.checkScript(test_pass, (test_input,))
336        self.checkScript(test_early_return, (test_input, test_context))
337        self.checkScript(test_break, (test_input, test_context, test_list))
338        self.checkScript(test_continue, (test_input, test_context, test_list))
339        self.assertEqual(test_context.count, 2)
340        self.checkScript(test_serial, (test_input,))
341        self.checkScript(test_nested, (test_input,))
342        self.checkScript(test_combined, (test_input,))
343
344    def test_with_exceptions(self):
345        """
346        Check that exceptions thrown in the bodies of with-statements are
347        handled correctly.
348        """
349
350        @torch.jit.script
351        class Context:
352            """
353            This class implements a basic context manager interface for use in
354            the unit tests. Unlike Context, the stateful part of this class
355            is a Tensor that is mutated in-place so that modifications made in the
356            JIT interpreter are visible outside of it.
357            """
358
359            def __init__(self, start: int):
360                self.count = torch.tensor([start], dtype=torch.double)
361
362            def __enter__(self):
363                self.count.add_(0.3)
364                return self.count
365
366            def __exit__(self, type: Any, value: Any, tb: Any):
367                self.count.sub_(0.3)
368
369        make_global(Context)
370
371        @torch.jit.script
372        def method_that_raises() -> torch.Tensor:
373            raise Exception("raised exception")  # noqa: TRY002
374
375        @torch.jit.script
376        def test_exception(x: torch.Tensor, c: Context) -> torch.Tensor:
377            """
378            Test the case in which an exception is thrown while executing the body of a with-statement.
379            """
380            with c as _:
381                x += method_that_raises()
382
383            return x
384
385        @torch.jit.script
386        def test_exception_nested(x: torch.Tensor, c: Context) -> torch.Tensor:
387            """
388            Test the case in which an exception is thrown while executing the body of a nested with-statement.
389            """
390            with c as _:
391                with c as _:
392                    x += method_that_raises()
393
394            return x
395
396        @torch.jit.script
397        def with_that_raises(c: Context) -> torch.Tensor:
398            a = torch.tensor([1])
399
400            with c as _:
401                a += method_that_raises()
402
403            return a
404
405        @torch.jit.script
406        def test_exception_fn_call(x: torch.Tensor, c: Context) -> torch.Tensor:
407            """
408            Test the case in which an exception is thrown while there are active with-statements in two different
409            frames.
410            """
411            with c as _:
412                x += with_that_raises(c)
413
414            return x
415
416        c = Context(1)
417
418        # checkScript and checkScriptRaisesRegex cannot be used because the string frontend will
419        # not compile class types (of which Context, the context manager being used for this test
420        # is one).
421        with self.assertRaisesRegexWithHighlight(
422            Exception, r"raised exception", 'raise Exception("raised exception'
423        ):
424            test_exception(torch.randn(2), c)
425        self.assertEqual(c.count, 1)
426
427        with self.assertRaisesRegexWithHighlight(
428            Exception, r"raised exception", 'raise Exception("raised exception'
429        ):
430            test_exception_nested(torch.randn(2), c)
431        self.assertEqual(c.count, 1)
432
433        with self.assertRaisesRegexWithHighlight(
434            Exception, r"raised exception", 'raise Exception("raised exception'
435        ):
436            test_exception_fn_call(torch.randn(2), c)
437        self.assertEqual(c.count, 1)
438
439    def test_with_errors(self):
440        """
441        Check that errors related to with-statements are detected and reported correctly.
442        """
443
444        @torch.jit.script
445        class NoEnterNoExit:
446            """
447            This class is missing __enter__ and __exit__ methods.
448            """
449
450            def __init__(self) -> None:
451                self.count = 1
452
453        @torch.jit.script
454        class BadEnter:
455            """
456            This class has an __enter__ method with an incorrect signature.
457            """
458
459            def __init__(self) -> None:
460                self.count = 1
461
462            def __enter__(self, incr: int):  # noqa: PLE0302
463                self.count += incr
464
465            def __exit__(self, type: Any, value: Any, tb: Any):
466                pass
467
468        @torch.jit.script
469        class BadExit:
470            """
471            This class has an __exit__ method with an incorrect signature.
472            """
473
474            def __init__(self) -> None:
475                self.count = 1
476
477            def __enter__(self):
478                self.count += 1
479
480            def __exit__(self, type: Any, value: Any):  # noqa: PLE0302
481                pass
482
483        @torch.jit.script
484        class ExitIncorrectTypes:
485            """
486            This class has an __exit__ method with unsupported argument types.
487            """
488
489            def __init__(self) -> None:
490                self.count = 1
491
492            def __enter__(self):
493                self.count += 1
494
495            def __exit__(self, type: Any, value: int, tb: int):
496                pass
497
498        def test_no_enter_no_exit(x: torch.Tensor, cm: NoEnterNoExit) -> torch.Tensor:
499            with cm as _:
500                pass
501
502            return x
503
504        def test_bad_enter(x: torch.Tensor, cm: BadEnter) -> torch.Tensor:
505            with cm as _:
506                pass
507
508            return x
509
510        def test_bad_exit(x: torch.Tensor, cm: BadExit) -> torch.Tensor:
511            with cm as _:
512                pass
513
514            return x
515
516        def test_exit_incorrect_types(
517            x: torch.Tensor, cm: ExitIncorrectTypes
518        ) -> torch.Tensor:
519            with cm as _:
520                pass
521
522            return x
523
524        def test_enter_without_object():
525            with "not_object" as obj:
526                pass
527
528        test_tensor = torch.randn(5, dtype=torch.double)
529
530        with self.assertRaisesRegexWithHighlight(
531            RuntimeError, r"does not define __enter__ and __exit__ methods", "cm"
532        ):
533            self.checkScript(test_no_enter_no_exit, (test_tensor, NoEnterNoExit()))
534
535        with self.assertRaisesRegexWithHighlight(
536            RuntimeError,
537            r"__enter__ must have only one argument and one return value",
538            "cm",
539        ):
540            self.checkScript(test_bad_enter, (test_tensor, BadEnter()))
541
542        with self.assertRaisesRegexWithHighlight(
543            RuntimeError, r"__exit__ must have four arguments", "cm"
544        ):
545            self.checkScript(test_bad_exit, (test_tensor, BadExit()))
546
547        with self.assertRaisesRegexWithHighlight(
548            RuntimeError, r"argument 2 of __exit__ must have Any type", "cm"
549        ):
550            self.checkScript(
551                test_exit_incorrect_types, (test_tensor, ExitIncorrectTypes())
552            )
553
554        with self.assertRaisesRegexWithHighlight(
555            RuntimeError, r"must return an object", '"not_object"'
556        ):
557            self.checkScript(test_enter_without_object, ())
558
559    def test_with_no_grad(self):
560        """
561        Check that torch.no_grad() works. Most of these are adapted from
562        corresponding tests for eager-mode no_grad.
563        """
564
565        # Basic no_grad test.
566        def test_no_grad(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
567            with torch.no_grad():
568                w = x + y
569
570            return w
571
572        s = torch.jit.script(test_no_grad)
573        x = torch.ones(5, 5, requires_grad=True)
574        y = torch.ones(5, 5) * 4
575        w = s(x, y)
576
577        self.assertFalse(w.requires_grad)
578        self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5)))
579        self.assertIsNone(w.grad_fn)
580
581        # Test assignment of a grad-less Tensor to a Tensor with gradients
582        # in a no_grad block.
583        def test_no_grad_assignment(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
584            with torch.no_grad():
585                x[0] = y
586
587            return x
588
589        s = torch.jit.script(test_no_grad_assignment)
590        z = torch.randn(5)
591        w = s(x, z)
592        self.assertTrue(w.requires_grad)
593        self.assertIsNone(w.grad_fn)
594
595        # Check that @torch.jit.ignored functions respect no_grad when it is
596        # called in JIT mode.
597        class NoGradModule(torch.nn.Module):
598            @torch.jit.ignore
599            def adder(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
600                w = x + y
601                return w
602
603            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
604                with torch.no_grad():
605                    w = self.adder(x, y)
606
607                return w
608
609        s = torch.jit.script(NoGradModule())
610        w = s(x, y)
611
612        self.assertFalse(w.requires_grad)
613
614    @skipIfTorchDynamo("Torchdynamo cannot correctly handle profiler.profile calls")
615    def test_with_record_function(self):
616        """
617        Check that torch.autograd.profiler.record_function context manager is
618        torchscriptable.
619        """
620
621        def with_rf(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
622            with torch.autograd.profiler.record_function("foo"):
623                # Nested record_function.
624                with torch.autograd.profiler.record_function("nested"):
625                    a = x + y
626            return a
627
628        scripted = torch.jit.script(with_rf)
629        x, y = torch.ones(2), torch.ones(2)
630        with torch.autograd.profiler.profile() as p:
631            scripted(x, y)
632
633        # Need to call below to populate CPU children.
634        p.key_averages()
635        function_events = p.function_events
636        # Event with name "foo" should be recorded.
637        rf_events = [evt for evt in function_events if evt.name == "foo"]
638        self.assertEqual(len(rf_events), 1)
639        rf_event = rf_events[0]
640        child_events = rf_event.cpu_children
641        # Ensure we find nested record_function event
642        self.assertTrue("nested" in (child.name for child in child_events))
643        nested_function_event = [
644            evt for evt in function_events if evt.name == "nested"
645        ][0]
646        # Nested record function should have child "aten::add"
647        nested_child_events = nested_function_event.cpu_children
648        self.assertTrue("aten::add" in (child.name for child in nested_child_events))
649