xref: /aosp_15_r20/external/pytorch/test/jit/test_with.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport os
4*da0073e9SAndroid Build Coastguard Workerimport sys
5*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, List
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Workerimport torch
8*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import skipIfTorchDynamo
9*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase, make_global
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker
# Make the helper files in test/ importable. This module lives in test/jit/,
# so test/ is found by walking two directories up from its resolved path.
pytorch_test_dir = os.path.realpath(__file__)  # .../test/jit/test_with.py
pytorch_test_dir = os.path.dirname(pytorch_test_dir)  # .../test/jit
pytorch_test_dir = os.path.dirname(pytorch_test_dir)  # .../test
sys.path.append(pytorch_test_dir)

# This module only defines test cases; refuse direct execution and point the
# user at the test/test_jit.py driver instead.
if __name__ == "__main__":
    _usage = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage)
22*da0073e9SAndroid Build Coastguard Worker
23*da0073e9SAndroid Build Coastguard Worker
24*da0073e9SAndroid Build Coastguard Workerclass TestWith(JitTestCase):
25*da0073e9SAndroid Build Coastguard Worker    """
26*da0073e9SAndroid Build Coastguard Worker    A suite of tests for with statements.
27*da0073e9SAndroid Build Coastguard Worker    """
28*da0073e9SAndroid Build Coastguard Worker
29*da0073e9SAndroid Build Coastguard Worker    def test_with_as(self):
30*da0073e9SAndroid Build Coastguard Worker        """
31*da0073e9SAndroid Build Coastguard Worker        Check that with statements that use the 'as' keyword to bind expressions
32*da0073e9SAndroid Build Coastguard Worker        to targets work as expected.
33*da0073e9SAndroid Build Coastguard Worker        """
34*da0073e9SAndroid Build Coastguard Worker
35*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
36*da0073e9SAndroid Build Coastguard Worker        class Context:
37*da0073e9SAndroid Build Coastguard Worker            """
38*da0073e9SAndroid Build Coastguard Worker            This class implements a basic context manager interface for use in
39*da0073e9SAndroid Build Coastguard Worker            the unit tests. Unlike Context, the stateful part of this class
40*da0073e9SAndroid Build Coastguard Worker            is a Tensor that is mutated in-place so that modifications made in the
41*da0073e9SAndroid Build Coastguard Worker            JIT interpreter are visible outside of it.
42*da0073e9SAndroid Build Coastguard Worker            """
43*da0073e9SAndroid Build Coastguard Worker
44*da0073e9SAndroid Build Coastguard Worker            def __init__(self, start: int):
45*da0073e9SAndroid Build Coastguard Worker                self.count = torch.tensor([start], dtype=torch.double)
46*da0073e9SAndroid Build Coastguard Worker
47*da0073e9SAndroid Build Coastguard Worker            def __enter__(self):
48*da0073e9SAndroid Build Coastguard Worker                self.count.add_(0.3)
49*da0073e9SAndroid Build Coastguard Worker                return self.count
50*da0073e9SAndroid Build Coastguard Worker
51*da0073e9SAndroid Build Coastguard Worker            def __exit__(self, type: Any, value: Any, tb: Any) -> bool:
52*da0073e9SAndroid Build Coastguard Worker                self.count.sub_(0.3)
53*da0073e9SAndroid Build Coastguard Worker                return True
54*da0073e9SAndroid Build Coastguard Worker
55*da0073e9SAndroid Build Coastguard Worker        make_global(Context)
56*da0073e9SAndroid Build Coastguard Worker
57*da0073e9SAndroid Build Coastguard Worker        def test_basic(x: torch.Tensor) -> torch.Tensor:
58*da0073e9SAndroid Build Coastguard Worker            """Basic test with one with-statement."""
59*da0073e9SAndroid Build Coastguard Worker
60*da0073e9SAndroid Build Coastguard Worker            c = Context(1)
61*da0073e9SAndroid Build Coastguard Worker
62*da0073e9SAndroid Build Coastguard Worker            with c as mult:
63*da0073e9SAndroid Build Coastguard Worker                y = x + mult
64*da0073e9SAndroid Build Coastguard Worker
65*da0073e9SAndroid Build Coastguard Worker            y *= c.count
66*da0073e9SAndroid Build Coastguard Worker            return y
67*da0073e9SAndroid Build Coastguard Worker
68*da0073e9SAndroid Build Coastguard Worker        def test_pass(x: torch.Tensor) -> torch.Tensor:
69*da0073e9SAndroid Build Coastguard Worker            """
70*da0073e9SAndroid Build Coastguard Worker            Test with a pass statement inside a with-statement. Although
71*da0073e9SAndroid Build Coastguard Worker            the body of the with is empty, __enter__ and __exit__ should
72*da0073e9SAndroid Build Coastguard Worker            still be called.
73*da0073e9SAndroid Build Coastguard Worker            """
74*da0073e9SAndroid Build Coastguard Worker            c = Context(1)
75*da0073e9SAndroid Build Coastguard Worker
76*da0073e9SAndroid Build Coastguard Worker            with c as mult:
77*da0073e9SAndroid Build Coastguard Worker                pass
78*da0073e9SAndroid Build Coastguard Worker
79*da0073e9SAndroid Build Coastguard Worker            x *= c.count
80*da0073e9SAndroid Build Coastguard Worker            return x
81*da0073e9SAndroid Build Coastguard Worker
82*da0073e9SAndroid Build Coastguard Worker        def test_early_return(x: torch.Tensor, c: Context) -> torch.Tensor:
83*da0073e9SAndroid Build Coastguard Worker            """
84*da0073e9SAndroid Build Coastguard Worker            Test that returning early from inside a with-statement works
85*da0073e9SAndroid Build Coastguard Worker            as expected.
86*da0073e9SAndroid Build Coastguard Worker            """
87*da0073e9SAndroid Build Coastguard Worker            with c as mult:
88*da0073e9SAndroid Build Coastguard Worker                y = x + mult
89*da0073e9SAndroid Build Coastguard Worker                return y
90*da0073e9SAndroid Build Coastguard Worker
91*da0073e9SAndroid Build Coastguard Worker            x = y + y
92*da0073e9SAndroid Build Coastguard Worker            return x
93*da0073e9SAndroid Build Coastguard Worker
94*da0073e9SAndroid Build Coastguard Worker        def test_conditional_early_return(x: torch.Tensor, c: Context) -> torch.Tensor:
95*da0073e9SAndroid Build Coastguard Worker            """
96*da0073e9SAndroid Build Coastguard Worker            Test that conditionally returning early from inside a with-statement works
97*da0073e9SAndroid Build Coastguard Worker            as expected.
98*da0073e9SAndroid Build Coastguard Worker            """
99*da0073e9SAndroid Build Coastguard Worker            with c as mult:
100*da0073e9SAndroid Build Coastguard Worker                y = x + mult
101*da0073e9SAndroid Build Coastguard Worker                if mult > 0:
102*da0073e9SAndroid Build Coastguard Worker                    return y
103*da0073e9SAndroid Build Coastguard Worker
104*da0073e9SAndroid Build Coastguard Worker            x = y + y
105*da0073e9SAndroid Build Coastguard Worker            return x
106*da0073e9SAndroid Build Coastguard Worker
107*da0073e9SAndroid Build Coastguard Worker        def test_break(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor:
108*da0073e9SAndroid Build Coastguard Worker            """
109*da0073e9SAndroid Build Coastguard Worker            Test that breaking early from inside a with-statement works
110*da0073e9SAndroid Build Coastguard Worker            as expected.
111*da0073e9SAndroid Build Coastguard Worker            """
112*da0073e9SAndroid Build Coastguard Worker            with c as mult:
113*da0073e9SAndroid Build Coastguard Worker                for a in l:
114*da0073e9SAndroid Build Coastguard Worker                    if a == 0:
115*da0073e9SAndroid Build Coastguard Worker                        break
116*da0073e9SAndroid Build Coastguard Worker                    x += a * mult
117*da0073e9SAndroid Build Coastguard Worker
118*da0073e9SAndroid Build Coastguard Worker            return x
119*da0073e9SAndroid Build Coastguard Worker
120*da0073e9SAndroid Build Coastguard Worker        def test_continue(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor:
121*da0073e9SAndroid Build Coastguard Worker            """
122*da0073e9SAndroid Build Coastguard Worker            Test that using continue inside a with-statement works
123*da0073e9SAndroid Build Coastguard Worker            as expected.
124*da0073e9SAndroid Build Coastguard Worker            """
125*da0073e9SAndroid Build Coastguard Worker            with c as mult:
126*da0073e9SAndroid Build Coastguard Worker                for a in l:
127*da0073e9SAndroid Build Coastguard Worker                    if a == 0:
128*da0073e9SAndroid Build Coastguard Worker                        continue
129*da0073e9SAndroid Build Coastguard Worker                    x += a * mult
130*da0073e9SAndroid Build Coastguard Worker
131*da0073e9SAndroid Build Coastguard Worker            return x
132*da0073e9SAndroid Build Coastguard Worker
133*da0073e9SAndroid Build Coastguard Worker        def test_serial(x: torch.Tensor) -> torch.Tensor:
134*da0073e9SAndroid Build Coastguard Worker            """
135*da0073e9SAndroid Build Coastguard Worker            Test two with-statements in a row.
136*da0073e9SAndroid Build Coastguard Worker            """
137*da0073e9SAndroid Build Coastguard Worker            c = Context(1)
138*da0073e9SAndroid Build Coastguard Worker
139*da0073e9SAndroid Build Coastguard Worker            with c as mult:
140*da0073e9SAndroid Build Coastguard Worker                y = x + mult
141*da0073e9SAndroid Build Coastguard Worker
142*da0073e9SAndroid Build Coastguard Worker            with c as mult:
143*da0073e9SAndroid Build Coastguard Worker                y *= mult
144*da0073e9SAndroid Build Coastguard Worker
145*da0073e9SAndroid Build Coastguard Worker            return y
146*da0073e9SAndroid Build Coastguard Worker
147*da0073e9SAndroid Build Coastguard Worker        def test_nested(x: torch.Tensor) -> torch.Tensor:
148*da0073e9SAndroid Build Coastguard Worker            """
149*da0073e9SAndroid Build Coastguard Worker            Test nested with-statements.
150*da0073e9SAndroid Build Coastguard Worker            """
151*da0073e9SAndroid Build Coastguard Worker            c = Context(1)
152*da0073e9SAndroid Build Coastguard Worker
153*da0073e9SAndroid Build Coastguard Worker            with c as m:
154*da0073e9SAndroid Build Coastguard Worker                with c as n:
155*da0073e9SAndroid Build Coastguard Worker                    y = x + n
156*da0073e9SAndroid Build Coastguard Worker
157*da0073e9SAndroid Build Coastguard Worker                y *= m
158*da0073e9SAndroid Build Coastguard Worker
159*da0073e9SAndroid Build Coastguard Worker            return y
160*da0073e9SAndroid Build Coastguard Worker
161*da0073e9SAndroid Build Coastguard Worker        def test_combined(x: torch.Tensor) -> torch.Tensor:
162*da0073e9SAndroid Build Coastguard Worker            """
163*da0073e9SAndroid Build Coastguard Worker            Test a with-statement with multiple with items.
164*da0073e9SAndroid Build Coastguard Worker            """
165*da0073e9SAndroid Build Coastguard Worker            c = Context(1)
166*da0073e9SAndroid Build Coastguard Worker            d = Context(2)
167*da0073e9SAndroid Build Coastguard Worker
168*da0073e9SAndroid Build Coastguard Worker            with c as m, d as n:
169*da0073e9SAndroid Build Coastguard Worker                y = x + (m + n)
170*da0073e9SAndroid Build Coastguard Worker
171*da0073e9SAndroid Build Coastguard Worker            return y
172*da0073e9SAndroid Build Coastguard Worker
173*da0073e9SAndroid Build Coastguard Worker        test_input = torch.randn(2, 2)
174*da0073e9SAndroid Build Coastguard Worker        test_context = Context(2)
175*da0073e9SAndroid Build Coastguard Worker        test_list = [2, 0, 1, 3, 0, 2]
176*da0073e9SAndroid Build Coastguard Worker
177*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_basic, (test_input,))
178*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_pass, (test_input,))
179*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_early_return, (test_input, test_context))
180*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_break, (test_input, test_context, test_list))
181*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_continue, (test_input, test_context, test_list))
182*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(test_context.count, 2)
183*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_serial, (test_input,))
184*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_nested, (test_input,))
185*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_combined, (test_input,))
186*da0073e9SAndroid Build Coastguard Worker
187*da0073e9SAndroid Build Coastguard Worker    def test_with_no_as(self):
188*da0073e9SAndroid Build Coastguard Worker        """
189*da0073e9SAndroid Build Coastguard Worker        Check that with statements that do not use the 'as' keyword to bind expressions
190*da0073e9SAndroid Build Coastguard Worker        to targets work as expected.
191*da0073e9SAndroid Build Coastguard Worker        """
192*da0073e9SAndroid Build Coastguard Worker
193*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
194*da0073e9SAndroid Build Coastguard Worker        class Context:
195*da0073e9SAndroid Build Coastguard Worker            """
196*da0073e9SAndroid Build Coastguard Worker            This class implements a basic context manager interface for use in
197*da0073e9SAndroid Build Coastguard Worker            the unit tests. Unlike Context, the stateful part of this class
198*da0073e9SAndroid Build Coastguard Worker            is a Tensor that is mutated in-place so that modifications made in the
199*da0073e9SAndroid Build Coastguard Worker            JIT interpreter are visible outside of it.
200*da0073e9SAndroid Build Coastguard Worker            """
201*da0073e9SAndroid Build Coastguard Worker
202*da0073e9SAndroid Build Coastguard Worker            def __init__(self, start: int):
203*da0073e9SAndroid Build Coastguard Worker                self.count = torch.tensor([start], dtype=torch.double)
204*da0073e9SAndroid Build Coastguard Worker
205*da0073e9SAndroid Build Coastguard Worker            def __enter__(self):
206*da0073e9SAndroid Build Coastguard Worker                self.count.add_(0.3)
207*da0073e9SAndroid Build Coastguard Worker                return self.count
208*da0073e9SAndroid Build Coastguard Worker
209*da0073e9SAndroid Build Coastguard Worker            def __exit__(self, type: Any, value: Any, tb: Any):
210*da0073e9SAndroid Build Coastguard Worker                self.count.sub_(0.3)
211*da0073e9SAndroid Build Coastguard Worker
212*da0073e9SAndroid Build Coastguard Worker        make_global(Context)
213*da0073e9SAndroid Build Coastguard Worker
214*da0073e9SAndroid Build Coastguard Worker        def test_basic(x: torch.Tensor) -> torch.Tensor:
215*da0073e9SAndroid Build Coastguard Worker            """Basic test with one with-statement."""
216*da0073e9SAndroid Build Coastguard Worker
217*da0073e9SAndroid Build Coastguard Worker            c = Context(1)
218*da0073e9SAndroid Build Coastguard Worker
219*da0073e9SAndroid Build Coastguard Worker            with c:
220*da0073e9SAndroid Build Coastguard Worker                y = x + c.count
221*da0073e9SAndroid Build Coastguard Worker
222*da0073e9SAndroid Build Coastguard Worker            y *= c.count
223*da0073e9SAndroid Build Coastguard Worker            return y
224*da0073e9SAndroid Build Coastguard Worker
225*da0073e9SAndroid Build Coastguard Worker        def test_pass(x: torch.Tensor) -> torch.Tensor:
226*da0073e9SAndroid Build Coastguard Worker            """
227*da0073e9SAndroid Build Coastguard Worker            Test with a pass statement inside a with-statement. Although
228*da0073e9SAndroid Build Coastguard Worker            the body of the with is empty, __enter__ and __exit__ should
229*da0073e9SAndroid Build Coastguard Worker            still be called.
230*da0073e9SAndroid Build Coastguard Worker            """
231*da0073e9SAndroid Build Coastguard Worker            c = Context(1)
232*da0073e9SAndroid Build Coastguard Worker
233*da0073e9SAndroid Build Coastguard Worker            with c:
234*da0073e9SAndroid Build Coastguard Worker                pass
235*da0073e9SAndroid Build Coastguard Worker
236*da0073e9SAndroid Build Coastguard Worker            x *= c.count
237*da0073e9SAndroid Build Coastguard Worker            return x
238*da0073e9SAndroid Build Coastguard Worker
239*da0073e9SAndroid Build Coastguard Worker        def test_early_return(x: torch.Tensor, c: Context) -> torch.Tensor:
240*da0073e9SAndroid Build Coastguard Worker            """
241*da0073e9SAndroid Build Coastguard Worker            Test that returning early from inside a with-statement works
242*da0073e9SAndroid Build Coastguard Worker            as expected.
243*da0073e9SAndroid Build Coastguard Worker            """
244*da0073e9SAndroid Build Coastguard Worker            with c:
245*da0073e9SAndroid Build Coastguard Worker                y = x + c.count
246*da0073e9SAndroid Build Coastguard Worker                return y
247*da0073e9SAndroid Build Coastguard Worker
248*da0073e9SAndroid Build Coastguard Worker            x = y + y
249*da0073e9SAndroid Build Coastguard Worker            return x
250*da0073e9SAndroid Build Coastguard Worker
251*da0073e9SAndroid Build Coastguard Worker        def test_conditional_early_return(x: torch.Tensor, c: Context) -> torch.Tensor:
252*da0073e9SAndroid Build Coastguard Worker            """
253*da0073e9SAndroid Build Coastguard Worker            Test that conditionally returning early from inside a with-statement works
254*da0073e9SAndroid Build Coastguard Worker            as expected.
255*da0073e9SAndroid Build Coastguard Worker            """
256*da0073e9SAndroid Build Coastguard Worker            with c:
257*da0073e9SAndroid Build Coastguard Worker                y = x + c.count
258*da0073e9SAndroid Build Coastguard Worker                if c.count > 0:
259*da0073e9SAndroid Build Coastguard Worker                    return y
260*da0073e9SAndroid Build Coastguard Worker
261*da0073e9SAndroid Build Coastguard Worker            x = y + y
262*da0073e9SAndroid Build Coastguard Worker            return x
263*da0073e9SAndroid Build Coastguard Worker
264*da0073e9SAndroid Build Coastguard Worker        def test_break(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor:
265*da0073e9SAndroid Build Coastguard Worker            """
266*da0073e9SAndroid Build Coastguard Worker            Test that breaking early from inside a with-statement works
267*da0073e9SAndroid Build Coastguard Worker            as expected.
268*da0073e9SAndroid Build Coastguard Worker            """
269*da0073e9SAndroid Build Coastguard Worker            with c:
270*da0073e9SAndroid Build Coastguard Worker                for a in l:
271*da0073e9SAndroid Build Coastguard Worker                    if a == 0:
272*da0073e9SAndroid Build Coastguard Worker                        break
273*da0073e9SAndroid Build Coastguard Worker                    x += a * c.count
274*da0073e9SAndroid Build Coastguard Worker
275*da0073e9SAndroid Build Coastguard Worker            return x
276*da0073e9SAndroid Build Coastguard Worker
277*da0073e9SAndroid Build Coastguard Worker        def test_continue(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor:
278*da0073e9SAndroid Build Coastguard Worker            """
279*da0073e9SAndroid Build Coastguard Worker            Test that using continue inside a with-statement works
280*da0073e9SAndroid Build Coastguard Worker            as expected.
281*da0073e9SAndroid Build Coastguard Worker            """
282*da0073e9SAndroid Build Coastguard Worker            with c:
283*da0073e9SAndroid Build Coastguard Worker                for a in l:
284*da0073e9SAndroid Build Coastguard Worker                    if a == 0:
285*da0073e9SAndroid Build Coastguard Worker                        continue
286*da0073e9SAndroid Build Coastguard Worker                    x += a * c.count
287*da0073e9SAndroid Build Coastguard Worker
288*da0073e9SAndroid Build Coastguard Worker            return x
289*da0073e9SAndroid Build Coastguard Worker
290*da0073e9SAndroid Build Coastguard Worker        def test_serial(x: torch.Tensor) -> torch.Tensor:
291*da0073e9SAndroid Build Coastguard Worker            """
292*da0073e9SAndroid Build Coastguard Worker            Test two with-statements in a row.
293*da0073e9SAndroid Build Coastguard Worker            """
294*da0073e9SAndroid Build Coastguard Worker            c = Context(1)
295*da0073e9SAndroid Build Coastguard Worker
296*da0073e9SAndroid Build Coastguard Worker            with c:
297*da0073e9SAndroid Build Coastguard Worker                y = x + c.count
298*da0073e9SAndroid Build Coastguard Worker
299*da0073e9SAndroid Build Coastguard Worker            with c:
300*da0073e9SAndroid Build Coastguard Worker                y *= c.count
301*da0073e9SAndroid Build Coastguard Worker
302*da0073e9SAndroid Build Coastguard Worker            return y
303*da0073e9SAndroid Build Coastguard Worker
304*da0073e9SAndroid Build Coastguard Worker        def test_nested(x: torch.Tensor) -> torch.Tensor:
305*da0073e9SAndroid Build Coastguard Worker            """
306*da0073e9SAndroid Build Coastguard Worker            Test nested with-statements.
307*da0073e9SAndroid Build Coastguard Worker            """
308*da0073e9SAndroid Build Coastguard Worker            c = Context(1)
309*da0073e9SAndroid Build Coastguard Worker
310*da0073e9SAndroid Build Coastguard Worker            with c:
311*da0073e9SAndroid Build Coastguard Worker                with c:
312*da0073e9SAndroid Build Coastguard Worker                    y = x + c.count
313*da0073e9SAndroid Build Coastguard Worker
314*da0073e9SAndroid Build Coastguard Worker                y *= c.count
315*da0073e9SAndroid Build Coastguard Worker
316*da0073e9SAndroid Build Coastguard Worker            return y
317*da0073e9SAndroid Build Coastguard Worker
318*da0073e9SAndroid Build Coastguard Worker        def test_combined(x: torch.Tensor) -> torch.Tensor:
319*da0073e9SAndroid Build Coastguard Worker            """
320*da0073e9SAndroid Build Coastguard Worker            Test a with-statement with multiple with items.
321*da0073e9SAndroid Build Coastguard Worker            """
322*da0073e9SAndroid Build Coastguard Worker            c = Context(1)
323*da0073e9SAndroid Build Coastguard Worker            d = Context(2)
324*da0073e9SAndroid Build Coastguard Worker
325*da0073e9SAndroid Build Coastguard Worker            with c, d:
326*da0073e9SAndroid Build Coastguard Worker                y = x + (c.count + d.count)
327*da0073e9SAndroid Build Coastguard Worker
328*da0073e9SAndroid Build Coastguard Worker            return y
329*da0073e9SAndroid Build Coastguard Worker
330*da0073e9SAndroid Build Coastguard Worker        test_input = torch.randn(2, 2)
331*da0073e9SAndroid Build Coastguard Worker        test_context = Context(2)
332*da0073e9SAndroid Build Coastguard Worker        test_list = [2, 0, 1, 3, 0, 2]
333*da0073e9SAndroid Build Coastguard Worker
334*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_basic, (test_input,))
335*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_pass, (test_input,))
336*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_early_return, (test_input, test_context))
337*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_break, (test_input, test_context, test_list))
338*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_continue, (test_input, test_context, test_list))
339*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(test_context.count, 2)
340*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_serial, (test_input,))
341*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_nested, (test_input,))
342*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_combined, (test_input,))
343*da0073e9SAndroid Build Coastguard Worker
    def test_with_exceptions(self):
        """
        Check that exceptions thrown in the bodies of with-statements are
        handled correctly.

        Each scripted function must let the exception propagate, and every
        __exit__ entered along the way must still run, leaving ``c.count``
        back at its starting value of 1.
        """

        @torch.jit.script
        class Context:
            """
            This class implements a basic context manager interface for use in
            the unit tests. Its state is a Tensor that is mutated in-place, so
            modifications made in the JIT interpreter are visible outside of it.

            __enter__ adds 0.3 to the counter and __exit__ subtracts it again.
            __exit__ has no return statement (it returns None), so it does not
            request that an in-flight exception be suppressed.
            """

            def __init__(self, start: int):
                self.count = torch.tensor([start], dtype=torch.double)

            def __enter__(self):
                self.count.add_(0.3)
                return self.count

            def __exit__(self, type: Any, value: Any, tb: Any):
                self.count.sub_(0.3)

        # Presumably registers Context globally so the scripted functions below,
        # which name it in their annotations, can resolve it.
        make_global(Context)

        # Always raises. The Tensor return annotation only gives the call sites
        # below (`x += method_that_raises()`) a type to compile against.
        @torch.jit.script
        def method_that_raises() -> torch.Tensor:
            raise Exception("raised exception")  # noqa: TRY002

        @torch.jit.script
        def test_exception(x: torch.Tensor, c: Context) -> torch.Tensor:
            """
            Test the case in which an exception is thrown while executing the body of a with-statement.
            """
            with c as _:
                x += method_that_raises()

            return x

        @torch.jit.script
        def test_exception_nested(x: torch.Tensor, c: Context) -> torch.Tensor:
            """
            Test the case in which an exception is thrown while executing the body of a nested with-statement.
            """
            with c as _:
                with c as _:
                    x += method_that_raises()

            return x

        # Enters `c` a second time from inside another scripted function, so the
        # exception unwinds through with-statements in two separate frames.
        @torch.jit.script
        def with_that_raises(c: Context) -> torch.Tensor:
            a = torch.tensor([1])

            with c as _:
                a += method_that_raises()

            return a

        @torch.jit.script
        def test_exception_fn_call(x: torch.Tensor, c: Context) -> torch.Tensor:
            """
            Test the case in which an exception is thrown while there are active with-statements in two different
            frames.
            """
            with c as _:
                x += with_that_raises(c)

            return x

        c = Context(1)

        # checkScript and checkScriptRaisesRegex cannot be used because the string
        # frontend will not compile class types (and Context, the context manager
        # used by this test, is one).
        # Each case expects the error to surface with the `raise` line
        # highlighted, then checks that every __exit__ ran during unwinding, so
        # the counter is back at 1.
        with self.assertRaisesRegexWithHighlight(
            Exception, r"raised exception", 'raise Exception("raised exception'
        ):
            test_exception(torch.randn(2), c)
        self.assertEqual(c.count, 1)

        with self.assertRaisesRegexWithHighlight(
            Exception, r"raised exception", 'raise Exception("raised exception'
        ):
            test_exception_nested(torch.randn(2), c)
        self.assertEqual(c.count, 1)

        with self.assertRaisesRegexWithHighlight(
            Exception, r"raised exception", 'raise Exception("raised exception'
        ):
            test_exception_fn_call(torch.randn(2), c)
        self.assertEqual(c.count, 1)
438*da0073e9SAndroid Build Coastguard Worker
439*da0073e9SAndroid Build Coastguard Worker    def test_with_errors(self):
440*da0073e9SAndroid Build Coastguard Worker        """
441*da0073e9SAndroid Build Coastguard Worker        Check that errors related to with-statements are detected and reported correctly.
442*da0073e9SAndroid Build Coastguard Worker        """
443*da0073e9SAndroid Build Coastguard Worker
444*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
445*da0073e9SAndroid Build Coastguard Worker        class NoEnterNoExit:
446*da0073e9SAndroid Build Coastguard Worker            """
447*da0073e9SAndroid Build Coastguard Worker            This class is missing __enter__ and __exit__ methods.
448*da0073e9SAndroid Build Coastguard Worker            """
449*da0073e9SAndroid Build Coastguard Worker
450*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
451*da0073e9SAndroid Build Coastguard Worker                self.count = 1
452*da0073e9SAndroid Build Coastguard Worker
453*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
454*da0073e9SAndroid Build Coastguard Worker        class BadEnter:
455*da0073e9SAndroid Build Coastguard Worker            """
456*da0073e9SAndroid Build Coastguard Worker            This class has an __enter__ method with an incorrect signature.
457*da0073e9SAndroid Build Coastguard Worker            """
458*da0073e9SAndroid Build Coastguard Worker
459*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
460*da0073e9SAndroid Build Coastguard Worker                self.count = 1
461*da0073e9SAndroid Build Coastguard Worker
462*da0073e9SAndroid Build Coastguard Worker            def __enter__(self, incr: int):  # noqa: PLE0302
463*da0073e9SAndroid Build Coastguard Worker                self.count += incr
464*da0073e9SAndroid Build Coastguard Worker
465*da0073e9SAndroid Build Coastguard Worker            def __exit__(self, type: Any, value: Any, tb: Any):
466*da0073e9SAndroid Build Coastguard Worker                pass
467*da0073e9SAndroid Build Coastguard Worker
468*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
469*da0073e9SAndroid Build Coastguard Worker        class BadExit:
470*da0073e9SAndroid Build Coastguard Worker            """
471*da0073e9SAndroid Build Coastguard Worker            This class has an __exit__ method with an incorrect signature.
472*da0073e9SAndroid Build Coastguard Worker            """
473*da0073e9SAndroid Build Coastguard Worker
474*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
475*da0073e9SAndroid Build Coastguard Worker                self.count = 1
476*da0073e9SAndroid Build Coastguard Worker
477*da0073e9SAndroid Build Coastguard Worker            def __enter__(self):
478*da0073e9SAndroid Build Coastguard Worker                self.count += 1
479*da0073e9SAndroid Build Coastguard Worker
480*da0073e9SAndroid Build Coastguard Worker            def __exit__(self, type: Any, value: Any):  # noqa: PLE0302
481*da0073e9SAndroid Build Coastguard Worker                pass
482*da0073e9SAndroid Build Coastguard Worker
483*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
484*da0073e9SAndroid Build Coastguard Worker        class ExitIncorrectTypes:
485*da0073e9SAndroid Build Coastguard Worker            """
486*da0073e9SAndroid Build Coastguard Worker            This class has an __exit__ method with unsupported argument types.
487*da0073e9SAndroid Build Coastguard Worker            """
488*da0073e9SAndroid Build Coastguard Worker
489*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
490*da0073e9SAndroid Build Coastguard Worker                self.count = 1
491*da0073e9SAndroid Build Coastguard Worker
492*da0073e9SAndroid Build Coastguard Worker            def __enter__(self):
493*da0073e9SAndroid Build Coastguard Worker                self.count += 1
494*da0073e9SAndroid Build Coastguard Worker
495*da0073e9SAndroid Build Coastguard Worker            def __exit__(self, type: Any, value: int, tb: int):
496*da0073e9SAndroid Build Coastguard Worker                pass
497*da0073e9SAndroid Build Coastguard Worker
498*da0073e9SAndroid Build Coastguard Worker        def test_no_enter_no_exit(x: torch.Tensor, cm: NoEnterNoExit) -> torch.Tensor:
499*da0073e9SAndroid Build Coastguard Worker            with cm as _:
500*da0073e9SAndroid Build Coastguard Worker                pass
501*da0073e9SAndroid Build Coastguard Worker
502*da0073e9SAndroid Build Coastguard Worker            return x
503*da0073e9SAndroid Build Coastguard Worker
504*da0073e9SAndroid Build Coastguard Worker        def test_bad_enter(x: torch.Tensor, cm: BadEnter) -> torch.Tensor:
505*da0073e9SAndroid Build Coastguard Worker            with cm as _:
506*da0073e9SAndroid Build Coastguard Worker                pass
507*da0073e9SAndroid Build Coastguard Worker
508*da0073e9SAndroid Build Coastguard Worker            return x
509*da0073e9SAndroid Build Coastguard Worker
510*da0073e9SAndroid Build Coastguard Worker        def test_bad_exit(x: torch.Tensor, cm: BadExit) -> torch.Tensor:
511*da0073e9SAndroid Build Coastguard Worker            with cm as _:
512*da0073e9SAndroid Build Coastguard Worker                pass
513*da0073e9SAndroid Build Coastguard Worker
514*da0073e9SAndroid Build Coastguard Worker            return x
515*da0073e9SAndroid Build Coastguard Worker
516*da0073e9SAndroid Build Coastguard Worker        def test_exit_incorrect_types(
517*da0073e9SAndroid Build Coastguard Worker            x: torch.Tensor, cm: ExitIncorrectTypes
518*da0073e9SAndroid Build Coastguard Worker        ) -> torch.Tensor:
519*da0073e9SAndroid Build Coastguard Worker            with cm as _:
520*da0073e9SAndroid Build Coastguard Worker                pass
521*da0073e9SAndroid Build Coastguard Worker
522*da0073e9SAndroid Build Coastguard Worker            return x
523*da0073e9SAndroid Build Coastguard Worker
524*da0073e9SAndroid Build Coastguard Worker        def test_enter_without_object():
525*da0073e9SAndroid Build Coastguard Worker            with "not_object" as obj:
526*da0073e9SAndroid Build Coastguard Worker                pass
527*da0073e9SAndroid Build Coastguard Worker
528*da0073e9SAndroid Build Coastguard Worker        test_tensor = torch.randn(5, dtype=torch.double)
529*da0073e9SAndroid Build Coastguard Worker
530*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegexWithHighlight(
531*da0073e9SAndroid Build Coastguard Worker            RuntimeError, r"does not define __enter__ and __exit__ methods", "cm"
532*da0073e9SAndroid Build Coastguard Worker        ):
533*da0073e9SAndroid Build Coastguard Worker            self.checkScript(test_no_enter_no_exit, (test_tensor, NoEnterNoExit()))
534*da0073e9SAndroid Build Coastguard Worker
535*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegexWithHighlight(
536*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
537*da0073e9SAndroid Build Coastguard Worker            r"__enter__ must have only one argument and one return value",
538*da0073e9SAndroid Build Coastguard Worker            "cm",
539*da0073e9SAndroid Build Coastguard Worker        ):
540*da0073e9SAndroid Build Coastguard Worker            self.checkScript(test_bad_enter, (test_tensor, BadEnter()))
541*da0073e9SAndroid Build Coastguard Worker
542*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegexWithHighlight(
543*da0073e9SAndroid Build Coastguard Worker            RuntimeError, r"__exit__ must have four arguments", "cm"
544*da0073e9SAndroid Build Coastguard Worker        ):
545*da0073e9SAndroid Build Coastguard Worker            self.checkScript(test_bad_exit, (test_tensor, BadExit()))
546*da0073e9SAndroid Build Coastguard Worker
547*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegexWithHighlight(
548*da0073e9SAndroid Build Coastguard Worker            RuntimeError, r"argument 2 of __exit__ must have Any type", "cm"
549*da0073e9SAndroid Build Coastguard Worker        ):
550*da0073e9SAndroid Build Coastguard Worker            self.checkScript(
551*da0073e9SAndroid Build Coastguard Worker                test_exit_incorrect_types, (test_tensor, ExitIncorrectTypes())
552*da0073e9SAndroid Build Coastguard Worker            )
553*da0073e9SAndroid Build Coastguard Worker
554*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegexWithHighlight(
555*da0073e9SAndroid Build Coastguard Worker            RuntimeError, r"must return an object", '"not_object"'
556*da0073e9SAndroid Build Coastguard Worker        ):
557*da0073e9SAndroid Build Coastguard Worker            self.checkScript(test_enter_without_object, ())
558*da0073e9SAndroid Build Coastguard Worker
559*da0073e9SAndroid Build Coastguard Worker    def test_with_no_grad(self):
560*da0073e9SAndroid Build Coastguard Worker        """
561*da0073e9SAndroid Build Coastguard Worker        Check that torch.no_grad() works. Most of these are adapted from
562*da0073e9SAndroid Build Coastguard Worker        corresponding tests for eager-mode no_grad.
563*da0073e9SAndroid Build Coastguard Worker        """
564*da0073e9SAndroid Build Coastguard Worker
565*da0073e9SAndroid Build Coastguard Worker        # Basic no_grad test.
566*da0073e9SAndroid Build Coastguard Worker        def test_no_grad(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
567*da0073e9SAndroid Build Coastguard Worker            with torch.no_grad():
568*da0073e9SAndroid Build Coastguard Worker                w = x + y
569*da0073e9SAndroid Build Coastguard Worker
570*da0073e9SAndroid Build Coastguard Worker            return w
571*da0073e9SAndroid Build Coastguard Worker
572*da0073e9SAndroid Build Coastguard Worker        s = torch.jit.script(test_no_grad)
573*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(5, 5, requires_grad=True)
574*da0073e9SAndroid Build Coastguard Worker        y = torch.ones(5, 5) * 4
575*da0073e9SAndroid Build Coastguard Worker        w = s(x, y)
576*da0073e9SAndroid Build Coastguard Worker
577*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(w.requires_grad)
578*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5)))
579*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(w.grad_fn)
580*da0073e9SAndroid Build Coastguard Worker
581*da0073e9SAndroid Build Coastguard Worker        # Test assignment of a grad-less Tensor to a Tensor with gradients
582*da0073e9SAndroid Build Coastguard Worker        # in a no_grad block.
583*da0073e9SAndroid Build Coastguard Worker        def test_no_grad_assignment(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
584*da0073e9SAndroid Build Coastguard Worker            with torch.no_grad():
585*da0073e9SAndroid Build Coastguard Worker                x[0] = y
586*da0073e9SAndroid Build Coastguard Worker
587*da0073e9SAndroid Build Coastguard Worker            return x
588*da0073e9SAndroid Build Coastguard Worker
589*da0073e9SAndroid Build Coastguard Worker        s = torch.jit.script(test_no_grad_assignment)
590*da0073e9SAndroid Build Coastguard Worker        z = torch.randn(5)
591*da0073e9SAndroid Build Coastguard Worker        w = s(x, z)
592*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(w.requires_grad)
593*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(w.grad_fn)
594*da0073e9SAndroid Build Coastguard Worker
595*da0073e9SAndroid Build Coastguard Worker        # Check that @torch.jit.ignored functions respect no_grad when it is
596*da0073e9SAndroid Build Coastguard Worker        # called in JIT mode.
597*da0073e9SAndroid Build Coastguard Worker        class NoGradModule(torch.nn.Module):
598*da0073e9SAndroid Build Coastguard Worker            @torch.jit.ignore
599*da0073e9SAndroid Build Coastguard Worker            def adder(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
600*da0073e9SAndroid Build Coastguard Worker                w = x + y
601*da0073e9SAndroid Build Coastguard Worker                return w
602*da0073e9SAndroid Build Coastguard Worker
603*da0073e9SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
604*da0073e9SAndroid Build Coastguard Worker                with torch.no_grad():
605*da0073e9SAndroid Build Coastguard Worker                    w = self.adder(x, y)
606*da0073e9SAndroid Build Coastguard Worker
607*da0073e9SAndroid Build Coastguard Worker                return w
608*da0073e9SAndroid Build Coastguard Worker
609*da0073e9SAndroid Build Coastguard Worker        s = torch.jit.script(NoGradModule())
610*da0073e9SAndroid Build Coastguard Worker        w = s(x, y)
611*da0073e9SAndroid Build Coastguard Worker
612*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(w.requires_grad)
613*da0073e9SAndroid Build Coastguard Worker
614*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Torchdynamo cannot correctly handle profiler.profile calls")
615*da0073e9SAndroid Build Coastguard Worker    def test_with_record_function(self):
616*da0073e9SAndroid Build Coastguard Worker        """
617*da0073e9SAndroid Build Coastguard Worker        Check that torch.autograd.profiler.record_function context manager is
618*da0073e9SAndroid Build Coastguard Worker        torchscriptable.
619*da0073e9SAndroid Build Coastguard Worker        """
620*da0073e9SAndroid Build Coastguard Worker
621*da0073e9SAndroid Build Coastguard Worker        def with_rf(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
622*da0073e9SAndroid Build Coastguard Worker            with torch.autograd.profiler.record_function("foo"):
623*da0073e9SAndroid Build Coastguard Worker                # Nested record_function.
624*da0073e9SAndroid Build Coastguard Worker                with torch.autograd.profiler.record_function("nested"):
625*da0073e9SAndroid Build Coastguard Worker                    a = x + y
626*da0073e9SAndroid Build Coastguard Worker            return a
627*da0073e9SAndroid Build Coastguard Worker
628*da0073e9SAndroid Build Coastguard Worker        scripted = torch.jit.script(with_rf)
629*da0073e9SAndroid Build Coastguard Worker        x, y = torch.ones(2), torch.ones(2)
630*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.profiler.profile() as p:
631*da0073e9SAndroid Build Coastguard Worker            scripted(x, y)
632*da0073e9SAndroid Build Coastguard Worker
633*da0073e9SAndroid Build Coastguard Worker        # Need to call below to populate CPU children.
634*da0073e9SAndroid Build Coastguard Worker        p.key_averages()
635*da0073e9SAndroid Build Coastguard Worker        function_events = p.function_events
636*da0073e9SAndroid Build Coastguard Worker        # Event with name "foo" should be recorded.
637*da0073e9SAndroid Build Coastguard Worker        rf_events = [evt for evt in function_events if evt.name == "foo"]
638*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(rf_events), 1)
639*da0073e9SAndroid Build Coastguard Worker        rf_event = rf_events[0]
640*da0073e9SAndroid Build Coastguard Worker        child_events = rf_event.cpu_children
641*da0073e9SAndroid Build Coastguard Worker        # Ensure we find nested record_function event
642*da0073e9SAndroid Build Coastguard Worker        self.assertTrue("nested" in (child.name for child in child_events))
643*da0073e9SAndroid Build Coastguard Worker        nested_function_event = [
644*da0073e9SAndroid Build Coastguard Worker            evt for evt in function_events if evt.name == "nested"
645*da0073e9SAndroid Build Coastguard Worker        ][0]
646*da0073e9SAndroid Build Coastguard Worker        # Nested record function should have child "aten::add"
647*da0073e9SAndroid Build Coastguard Worker        nested_child_events = nested_function_event.cpu_children
648*da0073e9SAndroid Build Coastguard Worker        self.assertTrue("aten::add" in (child.name for child in nested_child_events))
649