1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport os 4*da0073e9SAndroid Build Coastguard Workerimport sys 5*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, List 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Workerimport torch 8*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import skipIfTorchDynamo 9*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase, make_global 10*da0073e9SAndroid Build Coastguard Worker 11*da0073e9SAndroid Build Coastguard Worker 12*da0073e9SAndroid Build Coastguard Worker# Make the helper files in test/ importable 13*da0073e9SAndroid Build Coastguard Workerpytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 14*da0073e9SAndroid Build Coastguard Workersys.path.append(pytorch_test_dir) 15*da0073e9SAndroid Build Coastguard Worker 16*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 17*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 18*da0073e9SAndroid Build Coastguard Worker "This test file is not meant to be run directly, use:\n\n" 19*da0073e9SAndroid Build Coastguard Worker "\tpython test/test_jit.py TESTNAME\n\n" 20*da0073e9SAndroid Build Coastguard Worker "instead." 21*da0073e9SAndroid Build Coastguard Worker ) 22*da0073e9SAndroid Build Coastguard Worker 23*da0073e9SAndroid Build Coastguard Worker 24*da0073e9SAndroid Build Coastguard Workerclass TestWith(JitTestCase): 25*da0073e9SAndroid Build Coastguard Worker """ 26*da0073e9SAndroid Build Coastguard Worker A suite of tests for with statements. 27*da0073e9SAndroid Build Coastguard Worker """ 28*da0073e9SAndroid Build Coastguard Worker 29*da0073e9SAndroid Build Coastguard Worker def test_with_as(self): 30*da0073e9SAndroid Build Coastguard Worker """ 31*da0073e9SAndroid Build Coastguard Worker Check that with statements that use the 'as' keyword to bind expressions 32*da0073e9SAndroid Build Coastguard Worker to targets work as expected. 33*da0073e9SAndroid Build Coastguard Worker """ 34*da0073e9SAndroid Build Coastguard Worker 35*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 36*da0073e9SAndroid Build Coastguard Worker class Context: 37*da0073e9SAndroid Build Coastguard Worker """ 38*da0073e9SAndroid Build Coastguard Worker This class implements a basic context manager interface for use in 39*da0073e9SAndroid Build Coastguard Worker the unit tests. Unlike Context, the stateful part of this class 40*da0073e9SAndroid Build Coastguard Worker is a Tensor that is mutated in-place so that modifications made in the 41*da0073e9SAndroid Build Coastguard Worker JIT interpreter are visible outside of it. 42*da0073e9SAndroid Build Coastguard Worker """ 43*da0073e9SAndroid Build Coastguard Worker 44*da0073e9SAndroid Build Coastguard Worker def __init__(self, start: int): 45*da0073e9SAndroid Build Coastguard Worker self.count = torch.tensor([start], dtype=torch.double) 46*da0073e9SAndroid Build Coastguard Worker 47*da0073e9SAndroid Build Coastguard Worker def __enter__(self): 48*da0073e9SAndroid Build Coastguard Worker self.count.add_(0.3) 49*da0073e9SAndroid Build Coastguard Worker return self.count 50*da0073e9SAndroid Build Coastguard Worker 51*da0073e9SAndroid Build Coastguard Worker def __exit__(self, type: Any, value: Any, tb: Any) -> bool: 52*da0073e9SAndroid Build Coastguard Worker self.count.sub_(0.3) 53*da0073e9SAndroid Build Coastguard Worker return True 54*da0073e9SAndroid Build Coastguard Worker 55*da0073e9SAndroid Build Coastguard Worker make_global(Context) 56*da0073e9SAndroid Build Coastguard Worker 57*da0073e9SAndroid Build Coastguard Worker def test_basic(x: torch.Tensor) -> torch.Tensor: 58*da0073e9SAndroid Build Coastguard Worker """Basic test with one with-statement.""" 59*da0073e9SAndroid Build Coastguard Worker 60*da0073e9SAndroid Build Coastguard Worker c = Context(1) 61*da0073e9SAndroid Build Coastguard Worker 62*da0073e9SAndroid Build Coastguard Worker with c as mult: 63*da0073e9SAndroid Build Coastguard Worker y = x + mult 64*da0073e9SAndroid Build Coastguard Worker 65*da0073e9SAndroid Build Coastguard Worker y *= c.count 66*da0073e9SAndroid Build Coastguard Worker return y 67*da0073e9SAndroid Build Coastguard Worker 68*da0073e9SAndroid Build Coastguard Worker def test_pass(x: torch.Tensor) -> torch.Tensor: 69*da0073e9SAndroid Build Coastguard Worker """ 70*da0073e9SAndroid Build Coastguard Worker Test with a pass statement inside a with-statement. Although 71*da0073e9SAndroid Build Coastguard Worker the body of the with is empty, __enter__ and __exit__ should 72*da0073e9SAndroid Build Coastguard Worker still be called. 73*da0073e9SAndroid Build Coastguard Worker """ 74*da0073e9SAndroid Build Coastguard Worker c = Context(1) 75*da0073e9SAndroid Build Coastguard Worker 76*da0073e9SAndroid Build Coastguard Worker with c as mult: 77*da0073e9SAndroid Build Coastguard Worker pass 78*da0073e9SAndroid Build Coastguard Worker 79*da0073e9SAndroid Build Coastguard Worker x *= c.count 80*da0073e9SAndroid Build Coastguard Worker return x 81*da0073e9SAndroid Build Coastguard Worker 82*da0073e9SAndroid Build Coastguard Worker def test_early_return(x: torch.Tensor, c: Context) -> torch.Tensor: 83*da0073e9SAndroid Build Coastguard Worker """ 84*da0073e9SAndroid Build Coastguard Worker Test that returning early from inside a with-statement works 85*da0073e9SAndroid Build Coastguard Worker as expected. 86*da0073e9SAndroid Build Coastguard Worker """ 87*da0073e9SAndroid Build Coastguard Worker with c as mult: 88*da0073e9SAndroid Build Coastguard Worker y = x + mult 89*da0073e9SAndroid Build Coastguard Worker return y 90*da0073e9SAndroid Build Coastguard Worker 91*da0073e9SAndroid Build Coastguard Worker x = y + y 92*da0073e9SAndroid Build Coastguard Worker return x 93*da0073e9SAndroid Build Coastguard Worker 94*da0073e9SAndroid Build Coastguard Worker def test_conditional_early_return(x: torch.Tensor, c: Context) -> torch.Tensor: 95*da0073e9SAndroid Build Coastguard Worker """ 96*da0073e9SAndroid Build Coastguard Worker Test that conditionally returning early from inside a with-statement works 97*da0073e9SAndroid Build Coastguard Worker as expected. 98*da0073e9SAndroid Build Coastguard Worker """ 99*da0073e9SAndroid Build Coastguard Worker with c as mult: 100*da0073e9SAndroid Build Coastguard Worker y = x + mult 101*da0073e9SAndroid Build Coastguard Worker if mult > 0: 102*da0073e9SAndroid Build Coastguard Worker return y 103*da0073e9SAndroid Build Coastguard Worker 104*da0073e9SAndroid Build Coastguard Worker x = y + y 105*da0073e9SAndroid Build Coastguard Worker return x 106*da0073e9SAndroid Build Coastguard Worker 107*da0073e9SAndroid Build Coastguard Worker def test_break(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor: 108*da0073e9SAndroid Build Coastguard Worker """ 109*da0073e9SAndroid Build Coastguard Worker Test that breaking early from inside a with-statement works 110*da0073e9SAndroid Build Coastguard Worker as expected. 111*da0073e9SAndroid Build Coastguard Worker """ 112*da0073e9SAndroid Build Coastguard Worker with c as mult: 113*da0073e9SAndroid Build Coastguard Worker for a in l: 114*da0073e9SAndroid Build Coastguard Worker if a == 0: 115*da0073e9SAndroid Build Coastguard Worker break 116*da0073e9SAndroid Build Coastguard Worker x += a * mult 117*da0073e9SAndroid Build Coastguard Worker 118*da0073e9SAndroid Build Coastguard Worker return x 119*da0073e9SAndroid Build Coastguard Worker 120*da0073e9SAndroid Build Coastguard Worker def test_continue(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor: 121*da0073e9SAndroid Build Coastguard Worker """ 122*da0073e9SAndroid Build Coastguard Worker Test that using continue inside a with-statement works 123*da0073e9SAndroid Build Coastguard Worker as expected. 124*da0073e9SAndroid Build Coastguard Worker """ 125*da0073e9SAndroid Build Coastguard Worker with c as mult: 126*da0073e9SAndroid Build Coastguard Worker for a in l: 127*da0073e9SAndroid Build Coastguard Worker if a == 0: 128*da0073e9SAndroid Build Coastguard Worker continue 129*da0073e9SAndroid Build Coastguard Worker x += a * mult 130*da0073e9SAndroid Build Coastguard Worker 131*da0073e9SAndroid Build Coastguard Worker return x 132*da0073e9SAndroid Build Coastguard Worker 133*da0073e9SAndroid Build Coastguard Worker def test_serial(x: torch.Tensor) -> torch.Tensor: 134*da0073e9SAndroid Build Coastguard Worker """ 135*da0073e9SAndroid Build Coastguard Worker Test two with-statements in a row. 136*da0073e9SAndroid Build Coastguard Worker """ 137*da0073e9SAndroid Build Coastguard Worker c = Context(1) 138*da0073e9SAndroid Build Coastguard Worker 139*da0073e9SAndroid Build Coastguard Worker with c as mult: 140*da0073e9SAndroid Build Coastguard Worker y = x + mult 141*da0073e9SAndroid Build Coastguard Worker 142*da0073e9SAndroid Build Coastguard Worker with c as mult: 143*da0073e9SAndroid Build Coastguard Worker y *= mult 144*da0073e9SAndroid Build Coastguard Worker 145*da0073e9SAndroid Build Coastguard Worker return y 146*da0073e9SAndroid Build Coastguard Worker 147*da0073e9SAndroid Build Coastguard Worker def test_nested(x: torch.Tensor) -> torch.Tensor: 148*da0073e9SAndroid Build Coastguard Worker """ 149*da0073e9SAndroid Build Coastguard Worker Test nested with-statements. 150*da0073e9SAndroid Build Coastguard Worker """ 151*da0073e9SAndroid Build Coastguard Worker c = Context(1) 152*da0073e9SAndroid Build Coastguard Worker 153*da0073e9SAndroid Build Coastguard Worker with c as m: 154*da0073e9SAndroid Build Coastguard Worker with c as n: 155*da0073e9SAndroid Build Coastguard Worker y = x + n 156*da0073e9SAndroid Build Coastguard Worker 157*da0073e9SAndroid Build Coastguard Worker y *= m 158*da0073e9SAndroid Build Coastguard Worker 159*da0073e9SAndroid Build Coastguard Worker return y 160*da0073e9SAndroid Build Coastguard Worker 161*da0073e9SAndroid Build Coastguard Worker def test_combined(x: torch.Tensor) -> torch.Tensor: 162*da0073e9SAndroid Build Coastguard Worker """ 163*da0073e9SAndroid Build Coastguard Worker Test a with-statement with multiple with items. 164*da0073e9SAndroid Build Coastguard Worker """ 165*da0073e9SAndroid Build Coastguard Worker c = Context(1) 166*da0073e9SAndroid Build Coastguard Worker d = Context(2) 167*da0073e9SAndroid Build Coastguard Worker 168*da0073e9SAndroid Build Coastguard Worker with c as m, d as n: 169*da0073e9SAndroid Build Coastguard Worker y = x + (m + n) 170*da0073e9SAndroid Build Coastguard Worker 171*da0073e9SAndroid Build Coastguard Worker return y 172*da0073e9SAndroid Build Coastguard Worker 173*da0073e9SAndroid Build Coastguard Worker test_input = torch.randn(2, 2) 174*da0073e9SAndroid Build Coastguard Worker test_context = Context(2) 175*da0073e9SAndroid Build Coastguard Worker test_list = [2, 0, 1, 3, 0, 2] 176*da0073e9SAndroid Build Coastguard Worker 177*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_basic, (test_input,)) 178*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_pass, (test_input,)) 179*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_early_return, (test_input, test_context)) 180*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_break, (test_input, test_context, test_list)) 181*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_continue, (test_input, test_context, test_list)) 182*da0073e9SAndroid Build Coastguard Worker self.assertEqual(test_context.count, 2) 183*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_serial, (test_input,)) 184*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_nested, (test_input,)) 185*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_combined, (test_input,)) 186*da0073e9SAndroid Build Coastguard Worker 187*da0073e9SAndroid Build Coastguard Worker def test_with_no_as(self): 188*da0073e9SAndroid Build Coastguard Worker """ 189*da0073e9SAndroid Build Coastguard Worker Check that with statements that do not use the 'as' keyword to bind expressions 190*da0073e9SAndroid Build Coastguard Worker to targets work as expected. 191*da0073e9SAndroid Build Coastguard Worker """ 192*da0073e9SAndroid Build Coastguard Worker 193*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 194*da0073e9SAndroid Build Coastguard Worker class Context: 195*da0073e9SAndroid Build Coastguard Worker """ 196*da0073e9SAndroid Build Coastguard Worker This class implements a basic context manager interface for use in 197*da0073e9SAndroid Build Coastguard Worker the unit tests. Unlike Context, the stateful part of this class 198*da0073e9SAndroid Build Coastguard Worker is a Tensor that is mutated in-place so that modifications made in the 199*da0073e9SAndroid Build Coastguard Worker JIT interpreter are visible outside of it. 200*da0073e9SAndroid Build Coastguard Worker """ 201*da0073e9SAndroid Build Coastguard Worker 202*da0073e9SAndroid Build Coastguard Worker def __init__(self, start: int): 203*da0073e9SAndroid Build Coastguard Worker self.count = torch.tensor([start], dtype=torch.double) 204*da0073e9SAndroid Build Coastguard Worker 205*da0073e9SAndroid Build Coastguard Worker def __enter__(self): 206*da0073e9SAndroid Build Coastguard Worker self.count.add_(0.3) 207*da0073e9SAndroid Build Coastguard Worker return self.count 208*da0073e9SAndroid Build Coastguard Worker 209*da0073e9SAndroid Build Coastguard Worker def __exit__(self, type: Any, value: Any, tb: Any): 210*da0073e9SAndroid Build Coastguard Worker self.count.sub_(0.3) 211*da0073e9SAndroid Build Coastguard Worker 212*da0073e9SAndroid Build Coastguard Worker make_global(Context) 213*da0073e9SAndroid Build Coastguard Worker 214*da0073e9SAndroid Build Coastguard Worker def test_basic(x: torch.Tensor) -> torch.Tensor: 215*da0073e9SAndroid Build Coastguard Worker """Basic test with one with-statement.""" 216*da0073e9SAndroid Build Coastguard Worker 217*da0073e9SAndroid Build Coastguard Worker c = Context(1) 218*da0073e9SAndroid Build Coastguard Worker 219*da0073e9SAndroid Build Coastguard Worker with c: 220*da0073e9SAndroid Build Coastguard Worker y = x + c.count 221*da0073e9SAndroid Build Coastguard Worker 222*da0073e9SAndroid Build Coastguard Worker y *= c.count 223*da0073e9SAndroid Build Coastguard Worker return y 224*da0073e9SAndroid Build Coastguard Worker 225*da0073e9SAndroid Build Coastguard Worker def test_pass(x: torch.Tensor) -> torch.Tensor: 226*da0073e9SAndroid Build Coastguard Worker """ 227*da0073e9SAndroid Build Coastguard Worker Test with a pass statement inside a with-statement. Although 228*da0073e9SAndroid Build Coastguard Worker the body of the with is empty, __enter__ and __exit__ should 229*da0073e9SAndroid Build Coastguard Worker still be called. 230*da0073e9SAndroid Build Coastguard Worker """ 231*da0073e9SAndroid Build Coastguard Worker c = Context(1) 232*da0073e9SAndroid Build Coastguard Worker 233*da0073e9SAndroid Build Coastguard Worker with c: 234*da0073e9SAndroid Build Coastguard Worker pass 235*da0073e9SAndroid Build Coastguard Worker 236*da0073e9SAndroid Build Coastguard Worker x *= c.count 237*da0073e9SAndroid Build Coastguard Worker return x 238*da0073e9SAndroid Build Coastguard Worker 239*da0073e9SAndroid Build Coastguard Worker def test_early_return(x: torch.Tensor, c: Context) -> torch.Tensor: 240*da0073e9SAndroid Build Coastguard Worker """ 241*da0073e9SAndroid Build Coastguard Worker Test that returning early from inside a with-statement works 242*da0073e9SAndroid Build Coastguard Worker as expected. 243*da0073e9SAndroid Build Coastguard Worker """ 244*da0073e9SAndroid Build Coastguard Worker with c: 245*da0073e9SAndroid Build Coastguard Worker y = x + c.count 246*da0073e9SAndroid Build Coastguard Worker return y 247*da0073e9SAndroid Build Coastguard Worker 248*da0073e9SAndroid Build Coastguard Worker x = y + y 249*da0073e9SAndroid Build Coastguard Worker return x 250*da0073e9SAndroid Build Coastguard Worker 251*da0073e9SAndroid Build Coastguard Worker def test_conditional_early_return(x: torch.Tensor, c: Context) -> torch.Tensor: 252*da0073e9SAndroid Build Coastguard Worker """ 253*da0073e9SAndroid Build Coastguard Worker Test that conditionally returning early from inside a with-statement works 254*da0073e9SAndroid Build Coastguard Worker as expected. 255*da0073e9SAndroid Build Coastguard Worker """ 256*da0073e9SAndroid Build Coastguard Worker with c: 257*da0073e9SAndroid Build Coastguard Worker y = x + c.count 258*da0073e9SAndroid Build Coastguard Worker if c.count > 0: 259*da0073e9SAndroid Build Coastguard Worker return y 260*da0073e9SAndroid Build Coastguard Worker 261*da0073e9SAndroid Build Coastguard Worker x = y + y 262*da0073e9SAndroid Build Coastguard Worker return x 263*da0073e9SAndroid Build Coastguard Worker 264*da0073e9SAndroid Build Coastguard Worker def test_break(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor: 265*da0073e9SAndroid Build Coastguard Worker """ 266*da0073e9SAndroid Build Coastguard Worker Test that breaking early from inside a with-statement works 267*da0073e9SAndroid Build Coastguard Worker as expected. 268*da0073e9SAndroid Build Coastguard Worker """ 269*da0073e9SAndroid Build Coastguard Worker with c: 270*da0073e9SAndroid Build Coastguard Worker for a in l: 271*da0073e9SAndroid Build Coastguard Worker if a == 0: 272*da0073e9SAndroid Build Coastguard Worker break 273*da0073e9SAndroid Build Coastguard Worker x += a * c.count 274*da0073e9SAndroid Build Coastguard Worker 275*da0073e9SAndroid Build Coastguard Worker return x 276*da0073e9SAndroid Build Coastguard Worker 277*da0073e9SAndroid Build Coastguard Worker def test_continue(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor: 278*da0073e9SAndroid Build Coastguard Worker """ 279*da0073e9SAndroid Build Coastguard Worker Test that using continue inside a with-statement works 280*da0073e9SAndroid Build Coastguard Worker as expected. 281*da0073e9SAndroid Build Coastguard Worker """ 282*da0073e9SAndroid Build Coastguard Worker with c: 283*da0073e9SAndroid Build Coastguard Worker for a in l: 284*da0073e9SAndroid Build Coastguard Worker if a == 0: 285*da0073e9SAndroid Build Coastguard Worker continue 286*da0073e9SAndroid Build Coastguard Worker x += a * c.count 287*da0073e9SAndroid Build Coastguard Worker 288*da0073e9SAndroid Build Coastguard Worker return x 289*da0073e9SAndroid Build Coastguard Worker 290*da0073e9SAndroid Build Coastguard Worker def test_serial(x: torch.Tensor) -> torch.Tensor: 291*da0073e9SAndroid Build Coastguard Worker """ 292*da0073e9SAndroid Build Coastguard Worker Test two with-statements in a row. 293*da0073e9SAndroid Build Coastguard Worker """ 294*da0073e9SAndroid Build Coastguard Worker c = Context(1) 295*da0073e9SAndroid Build Coastguard Worker 296*da0073e9SAndroid Build Coastguard Worker with c: 297*da0073e9SAndroid Build Coastguard Worker y = x + c.count 298*da0073e9SAndroid Build Coastguard Worker 299*da0073e9SAndroid Build Coastguard Worker with c: 300*da0073e9SAndroid Build Coastguard Worker y *= c.count 301*da0073e9SAndroid Build Coastguard Worker 302*da0073e9SAndroid Build Coastguard Worker return y 303*da0073e9SAndroid Build Coastguard Worker 304*da0073e9SAndroid Build Coastguard Worker def test_nested(x: torch.Tensor) -> torch.Tensor: 305*da0073e9SAndroid Build Coastguard Worker """ 306*da0073e9SAndroid Build Coastguard Worker Test nested with-statements. 307*da0073e9SAndroid Build Coastguard Worker """ 308*da0073e9SAndroid Build Coastguard Worker c = Context(1) 309*da0073e9SAndroid Build Coastguard Worker 310*da0073e9SAndroid Build Coastguard Worker with c: 311*da0073e9SAndroid Build Coastguard Worker with c: 312*da0073e9SAndroid Build Coastguard Worker y = x + c.count 313*da0073e9SAndroid Build Coastguard Worker 314*da0073e9SAndroid Build Coastguard Worker y *= c.count 315*da0073e9SAndroid Build Coastguard Worker 316*da0073e9SAndroid Build Coastguard Worker return y 317*da0073e9SAndroid Build Coastguard Worker 318*da0073e9SAndroid Build Coastguard Worker def test_combined(x: torch.Tensor) -> torch.Tensor: 319*da0073e9SAndroid Build Coastguard Worker """ 320*da0073e9SAndroid Build Coastguard Worker Test a with-statement with multiple with items. 321*da0073e9SAndroid Build Coastguard Worker """ 322*da0073e9SAndroid Build Coastguard Worker c = Context(1) 323*da0073e9SAndroid Build Coastguard Worker d = Context(2) 324*da0073e9SAndroid Build Coastguard Worker 325*da0073e9SAndroid Build Coastguard Worker with c, d: 326*da0073e9SAndroid Build Coastguard Worker y = x + (c.count + d.count) 327*da0073e9SAndroid Build Coastguard Worker 328*da0073e9SAndroid Build Coastguard Worker return y 329*da0073e9SAndroid Build Coastguard Worker 330*da0073e9SAndroid Build Coastguard Worker test_input = torch.randn(2, 2) 331*da0073e9SAndroid Build Coastguard Worker test_context = Context(2) 332*da0073e9SAndroid Build Coastguard Worker test_list = [2, 0, 1, 3, 0, 2] 333*da0073e9SAndroid Build Coastguard Worker 334*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_basic, (test_input,)) 335*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_pass, (test_input,)) 336*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_early_return, (test_input, test_context)) 337*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_break, (test_input, test_context, test_list)) 338*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_continue, (test_input, test_context, test_list)) 339*da0073e9SAndroid Build Coastguard Worker self.assertEqual(test_context.count, 2) 340*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_serial, (test_input,)) 341*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_nested, (test_input,)) 342*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_combined, (test_input,)) 343*da0073e9SAndroid Build Coastguard Worker 344*da0073e9SAndroid Build Coastguard Worker def test_with_exceptions(self): 345*da0073e9SAndroid Build Coastguard Worker """ 346*da0073e9SAndroid Build Coastguard Worker Check that exceptions thrown in the bodies of with-statements are 347*da0073e9SAndroid Build Coastguard Worker handled correctly. 348*da0073e9SAndroid Build Coastguard Worker """ 349*da0073e9SAndroid Build Coastguard Worker 350*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 351*da0073e9SAndroid Build Coastguard Worker class Context: 352*da0073e9SAndroid Build Coastguard Worker """ 353*da0073e9SAndroid Build Coastguard Worker This class implements a basic context manager interface for use in 354*da0073e9SAndroid Build Coastguard Worker the unit tests. Unlike Context, the stateful part of this class 355*da0073e9SAndroid Build Coastguard Worker is a Tensor that is mutated in-place so that modifications made in the 356*da0073e9SAndroid Build Coastguard Worker JIT interpreter are visible outside of it. 357*da0073e9SAndroid Build Coastguard Worker """ 358*da0073e9SAndroid Build Coastguard Worker 359*da0073e9SAndroid Build Coastguard Worker def __init__(self, start: int): 360*da0073e9SAndroid Build Coastguard Worker self.count = torch.tensor([start], dtype=torch.double) 361*da0073e9SAndroid Build Coastguard Worker 362*da0073e9SAndroid Build Coastguard Worker def __enter__(self): 363*da0073e9SAndroid Build Coastguard Worker self.count.add_(0.3) 364*da0073e9SAndroid Build Coastguard Worker return self.count 365*da0073e9SAndroid Build Coastguard Worker 366*da0073e9SAndroid Build Coastguard Worker def __exit__(self, type: Any, value: Any, tb: Any): 367*da0073e9SAndroid Build Coastguard Worker self.count.sub_(0.3) 368*da0073e9SAndroid Build Coastguard Worker 369*da0073e9SAndroid Build Coastguard Worker make_global(Context) 370*da0073e9SAndroid Build Coastguard Worker 371*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 372*da0073e9SAndroid Build Coastguard Worker def method_that_raises() -> torch.Tensor: 373*da0073e9SAndroid Build Coastguard Worker raise Exception("raised exception") # noqa: TRY002 374*da0073e9SAndroid Build Coastguard Worker 375*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 376*da0073e9SAndroid Build Coastguard Worker def test_exception(x: torch.Tensor, c: Context) -> torch.Tensor: 377*da0073e9SAndroid Build Coastguard Worker """ 378*da0073e9SAndroid Build Coastguard Worker Test the case in which an exception is thrown while executing the body of a with-statement. 379*da0073e9SAndroid Build Coastguard Worker """ 380*da0073e9SAndroid Build Coastguard Worker with c as _: 381*da0073e9SAndroid Build Coastguard Worker x += method_that_raises() 382*da0073e9SAndroid Build Coastguard Worker 383*da0073e9SAndroid Build Coastguard Worker return x 384*da0073e9SAndroid Build Coastguard Worker 385*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 386*da0073e9SAndroid Build Coastguard Worker def test_exception_nested(x: torch.Tensor, c: Context) -> torch.Tensor: 387*da0073e9SAndroid Build Coastguard Worker """ 388*da0073e9SAndroid Build Coastguard Worker Test the case in which an exception is thrown while executing the body of a nested with-statement. 389*da0073e9SAndroid Build Coastguard Worker """ 390*da0073e9SAndroid Build Coastguard Worker with c as _: 391*da0073e9SAndroid Build Coastguard Worker with c as _: 392*da0073e9SAndroid Build Coastguard Worker x += method_that_raises() 393*da0073e9SAndroid Build Coastguard Worker 394*da0073e9SAndroid Build Coastguard Worker return x 395*da0073e9SAndroid Build Coastguard Worker 396*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 397*da0073e9SAndroid Build Coastguard Worker def with_that_raises(c: Context) -> torch.Tensor: 398*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([1]) 399*da0073e9SAndroid Build Coastguard Worker 400*da0073e9SAndroid Build Coastguard Worker with c as _: 401*da0073e9SAndroid Build Coastguard Worker a += method_that_raises() 402*da0073e9SAndroid Build Coastguard Worker 403*da0073e9SAndroid Build Coastguard Worker return a 404*da0073e9SAndroid Build Coastguard Worker 405*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 406*da0073e9SAndroid Build Coastguard Worker def test_exception_fn_call(x: torch.Tensor, c: Context) -> torch.Tensor: 407*da0073e9SAndroid Build Coastguard Worker """ 408*da0073e9SAndroid Build Coastguard Worker Test the case in which an exception is thrown while there are active with-statements in two different 409*da0073e9SAndroid Build Coastguard Worker frames. 410*da0073e9SAndroid Build Coastguard Worker """ 411*da0073e9SAndroid Build Coastguard Worker with c as _: 412*da0073e9SAndroid Build Coastguard Worker x += with_that_raises(c) 413*da0073e9SAndroid Build Coastguard Worker 414*da0073e9SAndroid Build Coastguard Worker return x 415*da0073e9SAndroid Build Coastguard Worker 416*da0073e9SAndroid Build Coastguard Worker c = Context(1) 417*da0073e9SAndroid Build Coastguard Worker 418*da0073e9SAndroid Build Coastguard Worker # checkScript and checkScriptRaisesRegex cannot be used because the string frontend will 419*da0073e9SAndroid Build Coastguard Worker # not compile class types (of which Context, the context manager being used for this test 420*da0073e9SAndroid Build Coastguard Worker # is one). 421*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegexWithHighlight( 422*da0073e9SAndroid Build Coastguard Worker Exception, r"raised exception", 'raise Exception("raised exception' 423*da0073e9SAndroid Build Coastguard Worker ): 424*da0073e9SAndroid Build Coastguard Worker test_exception(torch.randn(2), c) 425*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c.count, 1) 426*da0073e9SAndroid Build Coastguard Worker 427*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegexWithHighlight( 428*da0073e9SAndroid Build Coastguard Worker Exception, r"raised exception", 'raise Exception("raised exception' 429*da0073e9SAndroid Build Coastguard Worker ): 430*da0073e9SAndroid Build Coastguard Worker test_exception_nested(torch.randn(2), c) 431*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c.count, 1) 432*da0073e9SAndroid Build Coastguard Worker 433*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegexWithHighlight( 434*da0073e9SAndroid Build Coastguard Worker Exception, r"raised exception", 'raise Exception("raised exception' 435*da0073e9SAndroid Build Coastguard Worker ): 436*da0073e9SAndroid Build Coastguard Worker test_exception_fn_call(torch.randn(2), c) 437*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c.count, 1) 438*da0073e9SAndroid Build Coastguard Worker 439*da0073e9SAndroid Build Coastguard Worker def test_with_errors(self): 440*da0073e9SAndroid Build Coastguard Worker """ 441*da0073e9SAndroid Build Coastguard Worker Check that errors related to with-statements are detected and reported correctly. 442*da0073e9SAndroid Build Coastguard Worker """ 443*da0073e9SAndroid Build Coastguard Worker 444*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 445*da0073e9SAndroid Build Coastguard Worker class NoEnterNoExit: 446*da0073e9SAndroid Build Coastguard Worker """ 447*da0073e9SAndroid Build Coastguard Worker This class is missing __enter__ and __exit__ methods. 448*da0073e9SAndroid Build Coastguard Worker """ 449*da0073e9SAndroid Build Coastguard Worker 450*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 451*da0073e9SAndroid Build Coastguard Worker self.count = 1 452*da0073e9SAndroid Build Coastguard Worker 453*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 454*da0073e9SAndroid Build Coastguard Worker class BadEnter: 455*da0073e9SAndroid Build Coastguard Worker """ 456*da0073e9SAndroid Build Coastguard Worker This class has an __enter__ method with an incorrect signature. 457*da0073e9SAndroid Build Coastguard Worker """ 458*da0073e9SAndroid Build Coastguard Worker 459*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 460*da0073e9SAndroid Build Coastguard Worker self.count = 1 461*da0073e9SAndroid Build Coastguard Worker 462*da0073e9SAndroid Build Coastguard Worker def __enter__(self, incr: int): # noqa: PLE0302 463*da0073e9SAndroid Build Coastguard Worker self.count += incr 464*da0073e9SAndroid Build Coastguard Worker 465*da0073e9SAndroid Build Coastguard Worker def __exit__(self, type: Any, value: Any, tb: Any): 466*da0073e9SAndroid Build Coastguard Worker pass 467*da0073e9SAndroid Build Coastguard Worker 468*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 469*da0073e9SAndroid Build Coastguard Worker class BadExit: 470*da0073e9SAndroid Build Coastguard Worker """ 471*da0073e9SAndroid Build Coastguard Worker This class has an __exit__ method with an incorrect signature. 472*da0073e9SAndroid Build Coastguard Worker """ 473*da0073e9SAndroid Build Coastguard Worker 474*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 475*da0073e9SAndroid Build Coastguard Worker self.count = 1 476*da0073e9SAndroid Build Coastguard Worker 477*da0073e9SAndroid Build Coastguard Worker def __enter__(self): 478*da0073e9SAndroid Build Coastguard Worker self.count += 1 479*da0073e9SAndroid Build Coastguard Worker 480*da0073e9SAndroid Build Coastguard Worker def __exit__(self, type: Any, value: Any): # noqa: PLE0302 481*da0073e9SAndroid Build Coastguard Worker pass 482*da0073e9SAndroid Build Coastguard Worker 483*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 484*da0073e9SAndroid Build Coastguard Worker class ExitIncorrectTypes: 485*da0073e9SAndroid Build Coastguard Worker """ 486*da0073e9SAndroid Build Coastguard Worker This class has an __exit__ method with unsupported argument types. 487*da0073e9SAndroid Build Coastguard Worker """ 488*da0073e9SAndroid Build Coastguard Worker 489*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 490*da0073e9SAndroid Build Coastguard Worker self.count = 1 491*da0073e9SAndroid Build Coastguard Worker 492*da0073e9SAndroid Build Coastguard Worker def __enter__(self): 493*da0073e9SAndroid Build Coastguard Worker self.count += 1 494*da0073e9SAndroid Build Coastguard Worker 495*da0073e9SAndroid Build Coastguard Worker def __exit__(self, type: Any, value: int, tb: int): 496*da0073e9SAndroid Build Coastguard Worker pass 497*da0073e9SAndroid Build Coastguard Worker 498*da0073e9SAndroid Build Coastguard Worker def test_no_enter_no_exit(x: torch.Tensor, cm: NoEnterNoExit) -> torch.Tensor: 499*da0073e9SAndroid Build Coastguard Worker with cm as _: 500*da0073e9SAndroid Build Coastguard Worker pass 501*da0073e9SAndroid Build Coastguard Worker 502*da0073e9SAndroid Build Coastguard Worker return x 503*da0073e9SAndroid Build Coastguard Worker 504*da0073e9SAndroid Build Coastguard Worker def test_bad_enter(x: torch.Tensor, cm: BadEnter) -> torch.Tensor: 505*da0073e9SAndroid Build Coastguard Worker with cm as _: 506*da0073e9SAndroid Build Coastguard Worker pass 507*da0073e9SAndroid Build Coastguard Worker 508*da0073e9SAndroid Build Coastguard Worker return x 509*da0073e9SAndroid Build Coastguard Worker 510*da0073e9SAndroid Build Coastguard Worker def test_bad_exit(x: torch.Tensor, cm: BadExit) -> torch.Tensor: 511*da0073e9SAndroid Build Coastguard Worker with cm as _: 512*da0073e9SAndroid Build Coastguard Worker pass 513*da0073e9SAndroid Build Coastguard Worker 514*da0073e9SAndroid Build Coastguard Worker return x 515*da0073e9SAndroid Build Coastguard Worker 516*da0073e9SAndroid Build Coastguard Worker def test_exit_incorrect_types( 517*da0073e9SAndroid Build Coastguard Worker x: torch.Tensor, cm: ExitIncorrectTypes 518*da0073e9SAndroid Build Coastguard Worker ) -> torch.Tensor: 519*da0073e9SAndroid Build Coastguard Worker with cm as _: 520*da0073e9SAndroid Build Coastguard Worker pass 521*da0073e9SAndroid Build Coastguard Worker 522*da0073e9SAndroid Build Coastguard Worker return x 523*da0073e9SAndroid Build Coastguard Worker 524*da0073e9SAndroid Build Coastguard Worker def test_enter_without_object(): 525*da0073e9SAndroid Build Coastguard Worker with "not_object" as obj: 526*da0073e9SAndroid Build Coastguard Worker pass 527*da0073e9SAndroid Build Coastguard Worker 528*da0073e9SAndroid Build Coastguard Worker test_tensor = torch.randn(5, dtype=torch.double) 529*da0073e9SAndroid Build Coastguard Worker 530*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegexWithHighlight( 531*da0073e9SAndroid Build Coastguard Worker RuntimeError, r"does not define __enter__ and __exit__ methods", "cm" 532*da0073e9SAndroid Build Coastguard Worker ): 533*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_no_enter_no_exit, (test_tensor, NoEnterNoExit())) 534*da0073e9SAndroid Build Coastguard Worker 535*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegexWithHighlight( 536*da0073e9SAndroid Build Coastguard Worker RuntimeError, 537*da0073e9SAndroid Build Coastguard Worker r"__enter__ must have only one argument and one return value", 538*da0073e9SAndroid Build Coastguard Worker "cm", 539*da0073e9SAndroid Build Coastguard Worker ): 540*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_bad_enter, (test_tensor, BadEnter())) 541*da0073e9SAndroid Build Coastguard Worker 542*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegexWithHighlight( 543*da0073e9SAndroid Build Coastguard Worker RuntimeError, r"__exit__ must have four arguments", "cm" 544*da0073e9SAndroid Build Coastguard Worker ): 545*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_bad_exit, (test_tensor, BadExit())) 546*da0073e9SAndroid Build Coastguard Worker 547*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegexWithHighlight( 548*da0073e9SAndroid Build Coastguard Worker RuntimeError, r"argument 2 of __exit__ must have Any type", "cm" 549*da0073e9SAndroid Build Coastguard Worker ): 550*da0073e9SAndroid Build Coastguard Worker self.checkScript( 551*da0073e9SAndroid Build Coastguard Worker test_exit_incorrect_types, (test_tensor, ExitIncorrectTypes()) 552*da0073e9SAndroid Build Coastguard Worker ) 553*da0073e9SAndroid Build Coastguard Worker 554*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegexWithHighlight( 555*da0073e9SAndroid Build Coastguard Worker RuntimeError, r"must return an object", '"not_object"' 556*da0073e9SAndroid Build Coastguard Worker ): 557*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_enter_without_object, ()) 558*da0073e9SAndroid Build Coastguard Worker 559*da0073e9SAndroid Build Coastguard Worker def test_with_no_grad(self): 560*da0073e9SAndroid Build Coastguard Worker """ 561*da0073e9SAndroid Build Coastguard Worker Check that torch.no_grad() works. Most of these are adapted from 562*da0073e9SAndroid Build Coastguard Worker corresponding tests for eager-mode no_grad. 563*da0073e9SAndroid Build Coastguard Worker """ 564*da0073e9SAndroid Build Coastguard Worker 565*da0073e9SAndroid Build Coastguard Worker # Basic no_grad test. 566*da0073e9SAndroid Build Coastguard Worker def test_no_grad(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 567*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 568*da0073e9SAndroid Build Coastguard Worker w = x + y 569*da0073e9SAndroid Build Coastguard Worker 570*da0073e9SAndroid Build Coastguard Worker return w 571*da0073e9SAndroid Build Coastguard Worker 572*da0073e9SAndroid Build Coastguard Worker s = torch.jit.script(test_no_grad) 573*da0073e9SAndroid Build Coastguard Worker x = torch.ones(5, 5, requires_grad=True) 574*da0073e9SAndroid Build Coastguard Worker y = torch.ones(5, 5) * 4 575*da0073e9SAndroid Build Coastguard Worker w = s(x, y) 576*da0073e9SAndroid Build Coastguard Worker 577*da0073e9SAndroid Build Coastguard Worker self.assertFalse(w.requires_grad) 578*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5))) 579*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(w.grad_fn) 580*da0073e9SAndroid Build Coastguard Worker 581*da0073e9SAndroid Build Coastguard Worker # Test assignment of a grad-less Tensor to a Tensor with gradients 582*da0073e9SAndroid Build Coastguard Worker # in a no_grad block. 583*da0073e9SAndroid Build Coastguard Worker def test_no_grad_assignment(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 584*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 585*da0073e9SAndroid Build Coastguard Worker x[0] = y 586*da0073e9SAndroid Build Coastguard Worker 587*da0073e9SAndroid Build Coastguard Worker return x 588*da0073e9SAndroid Build Coastguard Worker 589*da0073e9SAndroid Build Coastguard Worker s = torch.jit.script(test_no_grad_assignment) 590*da0073e9SAndroid Build Coastguard Worker z = torch.randn(5) 591*da0073e9SAndroid Build Coastguard Worker w = s(x, z) 592*da0073e9SAndroid Build Coastguard Worker self.assertTrue(w.requires_grad) 593*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(w.grad_fn) 594*da0073e9SAndroid Build Coastguard Worker 595*da0073e9SAndroid Build Coastguard Worker # Check that @torch.jit.ignored functions respect no_grad when it is 596*da0073e9SAndroid Build Coastguard Worker # called in JIT mode. 597*da0073e9SAndroid Build Coastguard Worker class NoGradModule(torch.nn.Module): 598*da0073e9SAndroid Build Coastguard Worker @torch.jit.ignore 599*da0073e9SAndroid Build Coastguard Worker def adder(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 600*da0073e9SAndroid Build Coastguard Worker w = x + y 601*da0073e9SAndroid Build Coastguard Worker return w 602*da0073e9SAndroid Build Coastguard Worker 603*da0073e9SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 604*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 605*da0073e9SAndroid Build Coastguard Worker w = self.adder(x, y) 606*da0073e9SAndroid Build Coastguard Worker 607*da0073e9SAndroid Build Coastguard Worker return w 608*da0073e9SAndroid Build Coastguard Worker 609*da0073e9SAndroid Build Coastguard Worker s = torch.jit.script(NoGradModule()) 610*da0073e9SAndroid Build Coastguard Worker w = s(x, y) 611*da0073e9SAndroid Build Coastguard Worker 612*da0073e9SAndroid Build Coastguard Worker self.assertFalse(w.requires_grad) 613*da0073e9SAndroid Build Coastguard Worker 614*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Torchdynamo cannot correctly handle profiler.profile calls") 615*da0073e9SAndroid Build Coastguard Worker def test_with_record_function(self): 616*da0073e9SAndroid Build Coastguard Worker """ 617*da0073e9SAndroid Build Coastguard Worker Check that torch.autograd.profiler.record_function context manager is 618*da0073e9SAndroid Build Coastguard Worker torchscriptable. 619*da0073e9SAndroid Build Coastguard Worker """ 620*da0073e9SAndroid Build Coastguard Worker 621*da0073e9SAndroid Build Coastguard Worker def with_rf(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 622*da0073e9SAndroid Build Coastguard Worker with torch.autograd.profiler.record_function("foo"): 623*da0073e9SAndroid Build Coastguard Worker # Nested record_function. 624*da0073e9SAndroid Build Coastguard Worker with torch.autograd.profiler.record_function("nested"): 625*da0073e9SAndroid Build Coastguard Worker a = x + y 626*da0073e9SAndroid Build Coastguard Worker return a 627*da0073e9SAndroid Build Coastguard Worker 628*da0073e9SAndroid Build Coastguard Worker scripted = torch.jit.script(with_rf) 629*da0073e9SAndroid Build Coastguard Worker x, y = torch.ones(2), torch.ones(2) 630*da0073e9SAndroid Build Coastguard Worker with torch.autograd.profiler.profile() as p: 631*da0073e9SAndroid Build Coastguard Worker scripted(x, y) 632*da0073e9SAndroid Build Coastguard Worker 633*da0073e9SAndroid Build Coastguard Worker # Need to call below to populate CPU children. 634*da0073e9SAndroid Build Coastguard Worker p.key_averages() 635*da0073e9SAndroid Build Coastguard Worker function_events = p.function_events 636*da0073e9SAndroid Build Coastguard Worker # Event with name "foo" should be recorded. 637*da0073e9SAndroid Build Coastguard Worker rf_events = [evt for evt in function_events if evt.name == "foo"] 638*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(rf_events), 1) 639*da0073e9SAndroid Build Coastguard Worker rf_event = rf_events[0] 640*da0073e9SAndroid Build Coastguard Worker child_events = rf_event.cpu_children 641*da0073e9SAndroid Build Coastguard Worker # Ensure we find nested record_function event 642*da0073e9SAndroid Build Coastguard Worker self.assertTrue("nested" in (child.name for child in child_events)) 643*da0073e9SAndroid Build Coastguard Worker nested_function_event = [ 644*da0073e9SAndroid Build Coastguard Worker evt for evt in function_events if evt.name == "nested" 645*da0073e9SAndroid Build Coastguard Worker ][0] 646*da0073e9SAndroid Build Coastguard Worker # Nested record function should have child "aten::add" 647*da0073e9SAndroid Build Coastguard Worker nested_child_events = nested_function_event.cpu_children 648*da0073e9SAndroid Build Coastguard Worker self.assertTrue("aten::add" in (child.name for child in nested_child_events)) 649