1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport io 4*da0073e9SAndroid Build Coastguard Workerfrom typing import List, Optional, Tuple 5*da0073e9SAndroid Build Coastguard Worker 6*da0073e9SAndroid Build Coastguard Workerimport torch 7*da0073e9SAndroid Build Coastguard Workerfrom torch import Tensor 8*da0073e9SAndroid Build Coastguard Workerfrom torch._awaits import _Await as Await 9*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase, make_global 10*da0073e9SAndroid Build Coastguard Worker 11*da0073e9SAndroid Build Coastguard Worker 12*da0073e9SAndroid Build Coastguard Workerclass TestAwait(JitTestCase): 13*da0073e9SAndroid Build Coastguard Worker def test_await_python(self): 14*da0073e9SAndroid Build Coastguard Worker def foo(x: int) -> int: 15*da0073e9SAndroid Build Coastguard Worker return x + 13 16*da0073e9SAndroid Build Coastguard Worker 17*da0073e9SAndroid Build Coastguard Worker aw: Await[int] = torch.jit._awaitable(foo, 13) 18*da0073e9SAndroid Build Coastguard Worker self.assertTrue(aw.fn()(*aw.args()) == torch.jit._awaitable_wait(aw)) 19*da0073e9SAndroid Build Coastguard Worker nw = torch.jit._awaitable_nowait(33) 20*da0073e9SAndroid Build Coastguard Worker self.assertTrue(nw.is_nowait()) 21*da0073e9SAndroid Build Coastguard Worker self.assertTrue(nw.args() == (33,)) 22*da0073e9SAndroid Build Coastguard Worker 23*da0073e9SAndroid Build Coastguard Worker def test_await_type_python(self): 24*da0073e9SAndroid Build Coastguard Worker def foo() -> Tensor: 25*da0073e9SAndroid Build Coastguard Worker return torch.randn() 26*da0073e9SAndroid Build Coastguard Worker 27*da0073e9SAndroid Build Coastguard Worker awaits = torch.jit.annotate(List[Await[Tensor]], []) 28*da0073e9SAndroid Build Coastguard Worker awaits.append(torch.jit._awaitable(foo)) 29*da0073e9SAndroid Build Coastguard Worker 30*da0073e9SAndroid Build Coastguard Worker def test_script(self): 31*da0073e9SAndroid Build Coastguard Worker def delayed(z: int) -> int: 32*da0073e9SAndroid Build Coastguard Worker return z + 3 33*da0073e9SAndroid Build Coastguard Worker 34*da0073e9SAndroid Build Coastguard Worker def fn(x: Tensor): 35*da0073e9SAndroid Build Coastguard Worker aw: Await[int] = torch.jit._awaitable(delayed, 99) 36*da0073e9SAndroid Build Coastguard Worker a = torch.eye(2) 37*da0073e9SAndroid Build Coastguard Worker b = torch.jit._awaitable_wait(aw) 38*da0073e9SAndroid Build Coastguard Worker return a + b + x 39*da0073e9SAndroid Build Coastguard Worker 40*da0073e9SAndroid Build Coastguard Worker inp = torch.zeros(2) 41*da0073e9SAndroid Build Coastguard Worker 42*da0073e9SAndroid Build Coastguard Worker sm = torch.jit.script(fn) 43*da0073e9SAndroid Build Coastguard Worker out = fn(inp) 44*da0073e9SAndroid Build Coastguard Worker script_out = sm(inp) 45*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(torch.eye(2) + 102, script_out)) 46*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(script_out, out)) 47*da0073e9SAndroid Build Coastguard Worker 48*da0073e9SAndroid Build Coastguard Worker def test_nowait(self): 49*da0073e9SAndroid Build Coastguard Worker def fn(x: Tensor): 50*da0073e9SAndroid Build Coastguard Worker aw = torch.jit._awaitable_nowait(13) 51*da0073e9SAndroid Build Coastguard Worker a = torch.eye(2) 52*da0073e9SAndroid Build Coastguard Worker b = torch.jit._awaitable_wait(aw) 53*da0073e9SAndroid Build Coastguard Worker return a + b + x 54*da0073e9SAndroid Build Coastguard Worker 55*da0073e9SAndroid Build Coastguard Worker inp = torch.zeros(2) 56*da0073e9SAndroid Build Coastguard Worker 57*da0073e9SAndroid Build Coastguard Worker sm = torch.jit.script(fn) 58*da0073e9SAndroid Build Coastguard Worker out = fn(inp) 59*da0073e9SAndroid Build Coastguard Worker script_out = sm(inp) 60*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(torch.eye(2) + 13, script_out)) 61*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(script_out, out)) 62*da0073e9SAndroid Build Coastguard Worker 63*da0073e9SAndroid Build Coastguard Worker def test_nowait_class(self): 64*da0073e9SAndroid Build Coastguard Worker class C: 65*da0073e9SAndroid Build Coastguard Worker def __init__(self, a: Tensor, b: Tensor): 66*da0073e9SAndroid Build Coastguard Worker self._a = a 67*da0073e9SAndroid Build Coastguard Worker self._b = b 68*da0073e9SAndroid Build Coastguard Worker 69*da0073e9SAndroid Build Coastguard Worker def a(self) -> Tensor: 70*da0073e9SAndroid Build Coastguard Worker return self._a 71*da0073e9SAndroid Build Coastguard Worker 72*da0073e9SAndroid Build Coastguard Worker def fn(x: Tensor): 73*da0073e9SAndroid Build Coastguard Worker aw = torch.jit._awaitable_nowait(C(torch.zeros(2), torch.ones(2))) 74*da0073e9SAndroid Build Coastguard Worker _a = torch.eye(2) 75*da0073e9SAndroid Build Coastguard Worker c = torch.jit._awaitable_wait(aw) 76*da0073e9SAndroid Build Coastguard Worker return _a + c.a() + x 77*da0073e9SAndroid Build Coastguard Worker 78*da0073e9SAndroid Build Coastguard Worker make_global(C) 79*da0073e9SAndroid Build Coastguard Worker inp = torch.zeros(2) 80*da0073e9SAndroid Build Coastguard Worker 81*da0073e9SAndroid Build Coastguard Worker sm = torch.jit.script(fn) 82*da0073e9SAndroid Build Coastguard Worker out = fn(inp) 83*da0073e9SAndroid Build Coastguard Worker script_out = sm(inp) 84*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(torch.eye(2), script_out)) 85*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(script_out, out)) 86*da0073e9SAndroid Build Coastguard Worker 87*da0073e9SAndroid Build Coastguard Worker def test_await_class_arg(self): 88*da0073e9SAndroid Build Coastguard Worker class C: 89*da0073e9SAndroid Build Coastguard Worker def __init__(self, a: Tensor, b: Tensor): 90*da0073e9SAndroid Build Coastguard Worker self.__a = a 91*da0073e9SAndroid Build Coastguard Worker self.__b = b 92*da0073e9SAndroid Build Coastguard Worker 93*da0073e9SAndroid Build Coastguard Worker def a(self) -> Tensor: 94*da0073e9SAndroid Build Coastguard Worker return self.__a 95*da0073e9SAndroid Build Coastguard Worker 96*da0073e9SAndroid Build Coastguard Worker make_global(C) 97*da0073e9SAndroid Build Coastguard Worker 98*da0073e9SAndroid Build Coastguard Worker def delayed(c: C) -> Tensor: 99*da0073e9SAndroid Build Coastguard Worker return c.a() 100*da0073e9SAndroid Build Coastguard Worker 101*da0073e9SAndroid Build Coastguard Worker def fn(x: Tensor): 102*da0073e9SAndroid Build Coastguard Worker c = C(torch.zeros(2), torch.ones(2)) 103*da0073e9SAndroid Build Coastguard Worker aw = torch.jit._awaitable(delayed, c) 104*da0073e9SAndroid Build Coastguard Worker _a = torch.eye(2) 105*da0073e9SAndroid Build Coastguard Worker c2_t = torch.jit._awaitable_wait(aw) 106*da0073e9SAndroid Build Coastguard Worker return _a + c2_t + x 107*da0073e9SAndroid Build Coastguard Worker 108*da0073e9SAndroid Build Coastguard Worker inp = torch.zeros(2) 109*da0073e9SAndroid Build Coastguard Worker 110*da0073e9SAndroid Build Coastguard Worker sm = torch.jit.script(fn) 111*da0073e9SAndroid Build Coastguard Worker out = fn(inp) 112*da0073e9SAndroid Build Coastguard Worker script_out = sm(inp) 113*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(torch.eye(2), script_out)) 114*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(script_out, out)) 115*da0073e9SAndroid Build Coastguard Worker 116*da0073e9SAndroid Build Coastguard Worker def test_awaitable_to_await(self): 117*da0073e9SAndroid Build Coastguard Worker class C: 118*da0073e9SAndroid Build Coastguard Worker __slots__ = ["_a", "_b"] 119*da0073e9SAndroid Build Coastguard Worker 120*da0073e9SAndroid Build Coastguard Worker def __init__(self, a: Tensor, b: Tensor): 121*da0073e9SAndroid Build Coastguard Worker self._a = a 122*da0073e9SAndroid Build Coastguard Worker self._b = b 123*da0073e9SAndroid Build Coastguard Worker 124*da0073e9SAndroid Build Coastguard Worker make_global(C) 125*da0073e9SAndroid Build Coastguard Worker 126*da0073e9SAndroid Build Coastguard Worker # Can not stay in the class as Jit does not support Recursive annotations 127*da0073e9SAndroid Build Coastguard Worker # (self in wait_impl can not be annotated as C as C is not defined by this time) 128*da0073e9SAndroid Build Coastguard Worker def C_wait_impl(self: C): 129*da0073e9SAndroid Build Coastguard Worker return self._a + self._b 130*da0073e9SAndroid Build Coastguard Worker 131*da0073e9SAndroid Build Coastguard Worker def fn(x: Tensor): 132*da0073e9SAndroid Build Coastguard Worker aw = torch.jit._awaitable(C_wait_impl, C(torch.zeros(2), torch.ones(2))) 133*da0073e9SAndroid Build Coastguard Worker _a = torch.eye(2) 134*da0073e9SAndroid Build Coastguard Worker c_wait_impl_res = torch.jit._awaitable_wait(aw) 135*da0073e9SAndroid Build Coastguard Worker return _a + c_wait_impl_res + x 136*da0073e9SAndroid Build Coastguard Worker 137*da0073e9SAndroid Build Coastguard Worker inp = torch.ones(2) 138*da0073e9SAndroid Build Coastguard Worker 139*da0073e9SAndroid Build Coastguard Worker sm = torch.jit.script(fn) 140*da0073e9SAndroid Build Coastguard Worker out = fn(inp) 141*da0073e9SAndroid Build Coastguard Worker script_out = sm(inp) 142*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(torch.eye(2) + 2 * torch.ones(2), script_out)) 143*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(script_out, out)) 144*da0073e9SAndroid Build Coastguard Worker 145*da0073e9SAndroid Build Coastguard Worker def test_await_class_return(self): 146*da0073e9SAndroid Build Coastguard Worker class C: 147*da0073e9SAndroid Build Coastguard Worker __slots__ = ["a", "b"] 148*da0073e9SAndroid Build Coastguard Worker 149*da0073e9SAndroid Build Coastguard Worker def __init__(self, a: Tensor, b: Tensor): 150*da0073e9SAndroid Build Coastguard Worker self.a = a 151*da0073e9SAndroid Build Coastguard Worker self.b = b 152*da0073e9SAndroid Build Coastguard Worker 153*da0073e9SAndroid Build Coastguard Worker make_global(C) 154*da0073e9SAndroid Build Coastguard Worker 155*da0073e9SAndroid Build Coastguard Worker # Can not stay in the class as Jit does not support Recursive annotations 156*da0073e9SAndroid Build Coastguard Worker # (self in wait_impl can not be annotated as C as C is not defined by this time) 157*da0073e9SAndroid Build Coastguard Worker def C_wait_impl(self: C) -> C: 158*da0073e9SAndroid Build Coastguard Worker return C(self.a * 2, self.b * 3) 159*da0073e9SAndroid Build Coastguard Worker 160*da0073e9SAndroid Build Coastguard Worker def fn_arg_C(x: C) -> Tensor: 161*da0073e9SAndroid Build Coastguard Worker return x.a + x.b 162*da0073e9SAndroid Build Coastguard Worker 163*da0073e9SAndroid Build Coastguard Worker def fn(x: Tensor): 164*da0073e9SAndroid Build Coastguard Worker aw: Await[C] = torch.jit._awaitable(C_wait_impl, C(x, x)) 165*da0073e9SAndroid Build Coastguard Worker _a = torch.eye(2) 166*da0073e9SAndroid Build Coastguard Worker y = fn_arg_C(torch.jit._awaitable_wait(aw)) 167*da0073e9SAndroid Build Coastguard Worker return _a + y + x 168*da0073e9SAndroid Build Coastguard Worker 169*da0073e9SAndroid Build Coastguard Worker inp = torch.ones(2) 170*da0073e9SAndroid Build Coastguard Worker 171*da0073e9SAndroid Build Coastguard Worker sm = torch.jit.script(fn) 172*da0073e9SAndroid Build Coastguard Worker out = fn(inp) 173*da0073e9SAndroid Build Coastguard Worker script_out = sm(inp) 174*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(torch.eye(2) + 6 * torch.ones(2), script_out)) 175*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(script_out, out)) 176*da0073e9SAndroid Build Coastguard Worker self.assertGraphContainsExactly( 177*da0073e9SAndroid Build Coastguard Worker sm.graph, kind="prim::awaitable_wait", num_kind_nodes=1 178*da0073e9SAndroid Build Coastguard Worker ) 179*da0073e9SAndroid Build Coastguard Worker 180*da0073e9SAndroid Build Coastguard Worker def test_await_getattr_implicit_convertion(self): 181*da0073e9SAndroid Build Coastguard Worker class C: 182*da0073e9SAndroid Build Coastguard Worker def __init__(self, a: Tensor, b: Tensor): 183*da0073e9SAndroid Build Coastguard Worker self._a = a 184*da0073e9SAndroid Build Coastguard Worker self._b = b 185*da0073e9SAndroid Build Coastguard Worker 186*da0073e9SAndroid Build Coastguard Worker def b(self): 187*da0073e9SAndroid Build Coastguard Worker return self._b 188*da0073e9SAndroid Build Coastguard Worker 189*da0073e9SAndroid Build Coastguard Worker make_global(C) 190*da0073e9SAndroid Build Coastguard Worker 191*da0073e9SAndroid Build Coastguard Worker # Can not stay in the class as Jit does not support Recursive annotations 192*da0073e9SAndroid Build Coastguard Worker # (self in wait_impl can not be annotated as C as C is not defined by this time) 193*da0073e9SAndroid Build Coastguard Worker def C_wait_impl(self: C) -> C: 194*da0073e9SAndroid Build Coastguard Worker return C(self._a * 2, self._b * 3) 195*da0073e9SAndroid Build Coastguard Worker 196*da0073e9SAndroid Build Coastguard Worker def fn_arg_C(x: C) -> Tensor: 197*da0073e9SAndroid Build Coastguard Worker return x._a + x._b 198*da0073e9SAndroid Build Coastguard Worker 199*da0073e9SAndroid Build Coastguard Worker def fn(x: Tensor): 200*da0073e9SAndroid Build Coastguard Worker aw: Await[C] = torch.jit._awaitable(C_wait_impl, C(x, x)) 201*da0073e9SAndroid Build Coastguard Worker _a = torch.eye(2) 202*da0073e9SAndroid Build Coastguard Worker ai = aw._a 203*da0073e9SAndroid Build Coastguard Worker awb = aw.b() 204*da0073e9SAndroid Build Coastguard Worker c = C(2 * x, 2 * x) 205*da0073e9SAndroid Build Coastguard Worker return _a + ai + x + c._a + c.b() 206*da0073e9SAndroid Build Coastguard Worker 207*da0073e9SAndroid Build Coastguard Worker inp = torch.ones(2) 208*da0073e9SAndroid Build Coastguard Worker 209*da0073e9SAndroid Build Coastguard Worker sm = torch.jit.script(fn) 210*da0073e9SAndroid Build Coastguard Worker out = fn(inp) 211*da0073e9SAndroid Build Coastguard Worker script_out = sm(inp) 212*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(torch.eye(2) + 7 * torch.ones(2), script_out)) 213*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(script_out, out)) 214*da0073e9SAndroid Build Coastguard Worker self.assertGraphContainsExactly( 215*da0073e9SAndroid Build Coastguard Worker sm.graph, kind="prim::awaitable_wait", num_kind_nodes=2 216*da0073e9SAndroid Build Coastguard Worker ) 217*da0073e9SAndroid Build Coastguard Worker 218*da0073e9SAndroid Build Coastguard Worker def test_await_nested(self): 219*da0073e9SAndroid Build Coastguard Worker class C: 220*da0073e9SAndroid Build Coastguard Worker def __init__(self, a: Tensor, b: Tensor): 221*da0073e9SAndroid Build Coastguard Worker self.__a = a 222*da0073e9SAndroid Build Coastguard Worker self.__b = b 223*da0073e9SAndroid Build Coastguard Worker 224*da0073e9SAndroid Build Coastguard Worker def a(self) -> Tensor: 225*da0073e9SAndroid Build Coastguard Worker return self.__a 226*da0073e9SAndroid Build Coastguard Worker 227*da0073e9SAndroid Build Coastguard Worker make_global(C) 228*da0073e9SAndroid Build Coastguard Worker 229*da0073e9SAndroid Build Coastguard Worker def delayed(c: C) -> Await[Tensor]: 230*da0073e9SAndroid Build Coastguard Worker return torch.jit._awaitable_nowait(3 * c.a()) 231*da0073e9SAndroid Build Coastguard Worker 232*da0073e9SAndroid Build Coastguard Worker def fn(x: Tensor) -> Await[Await[Tensor]]: 233*da0073e9SAndroid Build Coastguard Worker return torch.jit._awaitable(delayed, C(2 * x, x)) 234*da0073e9SAndroid Build Coastguard Worker 235*da0073e9SAndroid Build Coastguard Worker def main(x: Tensor) -> Tensor: 236*da0073e9SAndroid Build Coastguard Worker awaw = fn(x) 237*da0073e9SAndroid Build Coastguard Worker return torch.jit._awaitable_wait(torch.jit._awaitable_wait(awaw)) 238*da0073e9SAndroid Build Coastguard Worker 239*da0073e9SAndroid Build Coastguard Worker inp = torch.eye(2) 240*da0073e9SAndroid Build Coastguard Worker 241*da0073e9SAndroid Build Coastguard Worker sm = torch.jit.script(main) 242*da0073e9SAndroid Build Coastguard Worker out = main(inp) 243*da0073e9SAndroid Build Coastguard Worker script_out = sm(inp) 244*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(6 * torch.eye(2), script_out)) 245*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(script_out, out)) 246*da0073e9SAndroid Build Coastguard Worker 247*da0073e9SAndroid Build Coastguard Worker def test_eager_await_non_scriptable(self): 248*da0073e9SAndroid Build Coastguard Worker # Tree type can not be compiled (Recursive type) 249*da0073e9SAndroid Build Coastguard Worker class Tree: 250*da0073e9SAndroid Build Coastguard Worker def __init__(self, v): 251*da0073e9SAndroid Build Coastguard Worker self.parent = torch.jit.annotate(Optional[Tree], None) 252*da0073e9SAndroid Build Coastguard Worker self.v = v 253*da0073e9SAndroid Build Coastguard Worker 254*da0073e9SAndroid Build Coastguard Worker make_global(Tree) 255*da0073e9SAndroid Build Coastguard Worker 256*da0073e9SAndroid Build Coastguard Worker def delayed(t: Tree): 257*da0073e9SAndroid Build Coastguard Worker t.v = t.v + 1 258*da0073e9SAndroid Build Coastguard Worker return t 259*da0073e9SAndroid Build Coastguard Worker 260*da0073e9SAndroid Build Coastguard Worker aw = torch.jit._awaitable(delayed, Tree(2)) 261*da0073e9SAndroid Build Coastguard Worker t = torch.jit._awaitable_wait(aw) 262*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t.v == 3) 263*da0073e9SAndroid Build Coastguard Worker 264*da0073e9SAndroid Build Coastguard Worker def test_await_isinstance(self): 265*da0073e9SAndroid Build Coastguard Worker def delayed(x: Tensor) -> Tensor: 266*da0073e9SAndroid Build Coastguard Worker return 2 * (x + 1) 267*da0073e9SAndroid Build Coastguard Worker 268*da0073e9SAndroid Build Coastguard Worker def main(x: Tensor) -> Tensor: 269*da0073e9SAndroid Build Coastguard Worker aw = torch.jit._awaitable(delayed, x) 270*da0073e9SAndroid Build Coastguard Worker if torch.jit.is_scripting(): 271*da0073e9SAndroid Build Coastguard Worker assert isinstance(aw, torch.jit._Await) 272*da0073e9SAndroid Build Coastguard Worker return torch.jit._awaitable_wait(aw) 273*da0073e9SAndroid Build Coastguard Worker 274*da0073e9SAndroid Build Coastguard Worker inp = torch.eye(2) 275*da0073e9SAndroid Build Coastguard Worker 276*da0073e9SAndroid Build Coastguard Worker sm = torch.jit.script(main) 277*da0073e9SAndroid Build Coastguard Worker out = main(inp) 278*da0073e9SAndroid Build Coastguard Worker script_out = sm(inp) 279*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 280*da0073e9SAndroid Build Coastguard Worker torch.allclose(2 * torch.eye(2) + 2 * torch.ones(2), script_out) 281*da0073e9SAndroid Build Coastguard Worker ) 282*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(script_out, out)) 283*da0073e9SAndroid Build Coastguard Worker 284*da0073e9SAndroid Build Coastguard Worker def test_await_eager_lazy(self): 285*da0073e9SAndroid Build Coastguard Worker def delayed(x: Tensor) -> Tensor: 286*da0073e9SAndroid Build Coastguard Worker return 2 * (x + 1) 287*da0073e9SAndroid Build Coastguard Worker 288*da0073e9SAndroid Build Coastguard Worker t = torch.ones(2, dtype=torch.int64) 289*da0073e9SAndroid Build Coastguard Worker aw = torch.jit._awaitable(delayed, t) 290*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(aw, torch._C._Await)) 291*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t.dtype == aw.dtype) 292*da0073e9SAndroid Build Coastguard Worker 293*da0073e9SAndroid Build Coastguard Worker def test_await_out_of_interpreter(self): 294*da0073e9SAndroid Build Coastguard Worker def delayed(x: Tensor) -> Tensor: 295*da0073e9SAndroid Build Coastguard Worker return 2 * (x + 1) 296*da0073e9SAndroid Build Coastguard Worker 297*da0073e9SAndroid Build Coastguard Worker def main(x: Tensor) -> Await[Tensor]: 298*da0073e9SAndroid Build Coastguard Worker aw = torch.jit._awaitable(delayed, x) 299*da0073e9SAndroid Build Coastguard Worker return aw 300*da0073e9SAndroid Build Coastguard Worker 301*da0073e9SAndroid Build Coastguard Worker inp = torch.eye(2) 302*da0073e9SAndroid Build Coastguard Worker 303*da0073e9SAndroid Build Coastguard Worker sm = torch.jit.script(main) 304*da0073e9SAndroid Build Coastguard Worker out_aw = main(inp) 305*da0073e9SAndroid Build Coastguard Worker out = torch.jit._awaitable_wait(out_aw) 306*da0073e9SAndroid Build Coastguard Worker 307*da0073e9SAndroid Build Coastguard Worker script_out_aw = sm(inp) 308*da0073e9SAndroid Build Coastguard Worker script_out = torch.jit._awaitable_wait(script_out_aw) 309*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 310*da0073e9SAndroid Build Coastguard Worker torch.allclose(2 * torch.eye(2) + 2 * torch.ones(2), script_out) 311*da0073e9SAndroid Build Coastguard Worker ) 312*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(script_out, out)) 313*da0073e9SAndroid Build Coastguard Worker 314*da0073e9SAndroid Build Coastguard Worker def test_jit_trace(self): 315*da0073e9SAndroid Build Coastguard Worker def gap(x: Tensor): 316*da0073e9SAndroid Build Coastguard Worker return torch.relu(x) + torch.sin(x) 317*da0073e9SAndroid Build Coastguard Worker 318*da0073e9SAndroid Build Coastguard Worker def delayed(x: Tensor) -> Tensor: 319*da0073e9SAndroid Build Coastguard Worker return 2 * (torch.cos(x) + 1) 320*da0073e9SAndroid Build Coastguard Worker 321*da0073e9SAndroid Build Coastguard Worker def main(x: Tensor, y: Tensor) -> Tensor: 322*da0073e9SAndroid Build Coastguard Worker aw = torch.jit._awaitable(delayed, x) 323*da0073e9SAndroid Build Coastguard Worker z = gap(y) 324*da0073e9SAndroid Build Coastguard Worker k = torch.jit._awaitable_wait(aw) 325*da0073e9SAndroid Build Coastguard Worker return y + k 326*da0073e9SAndroid Build Coastguard Worker 327*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(2) 328*da0073e9SAndroid Build Coastguard Worker tm = torch.jit.trace(main, (inp, inp)) 329*da0073e9SAndroid Build Coastguard Worker inp_check = torch.ones(2) 330*da0073e9SAndroid Build Coastguard Worker self.assertEqual(main(inp_check, inp_check), tm(inp_check, inp_check)) 331*da0073e9SAndroid Build Coastguard Worker 332*da0073e9SAndroid Build Coastguard Worker def test_await_multiout_save(self): 333*da0073e9SAndroid Build Coastguard Worker def gap(x: Tensor): 334*da0073e9SAndroid Build Coastguard Worker return torch.relu(x) + torch.sin(x) 335*da0073e9SAndroid Build Coastguard Worker 336*da0073e9SAndroid Build Coastguard Worker def delayed(x: Tensor) -> Tuple[Tensor, List[Tensor]]: 337*da0073e9SAndroid Build Coastguard Worker l = [x * i for i in range(5)] 338*da0073e9SAndroid Build Coastguard Worker return (100 * x, l) 339*da0073e9SAndroid Build Coastguard Worker 340*da0073e9SAndroid Build Coastguard Worker def main(x: Tensor) -> Tensor: 341*da0073e9SAndroid Build Coastguard Worker aw = torch.jit._awaitable(delayed, x) 342*da0073e9SAndroid Build Coastguard Worker z = gap(x) 343*da0073e9SAndroid Build Coastguard Worker (_, l) = torch.jit._awaitable_wait(aw) 344*da0073e9SAndroid Build Coastguard Worker return l[3] + z 345*da0073e9SAndroid Build Coastguard Worker 346*da0073e9SAndroid Build Coastguard Worker inp = torch.eye(2) 347*da0073e9SAndroid Build Coastguard Worker 348*da0073e9SAndroid Build Coastguard Worker sm = torch.jit.script(main) 349*da0073e9SAndroid Build Coastguard Worker out = main(inp) 350*da0073e9SAndroid Build Coastguard Worker script_out = sm(inp) 351*da0073e9SAndroid Build Coastguard Worker expected = 4.8415 * torch.eye(2) 352*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(expected, script_out)) 353*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(script_out, out)) 354*da0073e9SAndroid Build Coastguard Worker 355*da0073e9SAndroid Build Coastguard Worker iofile = io.BytesIO() 356*da0073e9SAndroid Build Coastguard Worker torch.jit.save(sm, iofile) 357*da0073e9SAndroid Build Coastguard Worker iofile.seek(0) 358*da0073e9SAndroid Build Coastguard Worker sm = torch.jit.load(iofile) 359*da0073e9SAndroid Build Coastguard Worker script_out_load = sm(inp) 360*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(expected, script_out_load)) 361*da0073e9SAndroid Build Coastguard Worker 362*da0073e9SAndroid Build Coastguard Worker def test_await_func_arg(self): 363*da0073e9SAndroid Build Coastguard Worker def gap(x: Tensor): 364*da0073e9SAndroid Build Coastguard Worker return torch.relu(x) + torch.sin(x) 365*da0073e9SAndroid Build Coastguard Worker 366*da0073e9SAndroid Build Coastguard Worker def delayed(x: Tensor) -> Tensor: 367*da0073e9SAndroid Build Coastguard Worker return -1 * x 368*da0073e9SAndroid Build Coastguard Worker 369*da0073e9SAndroid Build Coastguard Worker def fn(aw: Await[Tensor]) -> Tensor: 370*da0073e9SAndroid Build Coastguard Worker return 3 * torch.jit._awaitable_wait(aw) 371*da0073e9SAndroid Build Coastguard Worker 372*da0073e9SAndroid Build Coastguard Worker def main(x: Tensor) -> Tensor: 373*da0073e9SAndroid Build Coastguard Worker aw = torch.jit._awaitable(delayed, x) 374*da0073e9SAndroid Build Coastguard Worker z = gap(x) 375*da0073e9SAndroid Build Coastguard Worker y = fn(aw) 376*da0073e9SAndroid Build Coastguard Worker return y + x 377*da0073e9SAndroid Build Coastguard Worker 378*da0073e9SAndroid Build Coastguard Worker inp = torch.eye(2) 379*da0073e9SAndroid Build Coastguard Worker 380*da0073e9SAndroid Build Coastguard Worker sm = torch.jit.script(main) 381*da0073e9SAndroid Build Coastguard Worker out = main(inp) 382*da0073e9SAndroid Build Coastguard Worker script_out = sm(inp) 383*da0073e9SAndroid Build Coastguard Worker expected = -2 * torch.eye(2) 384*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(expected, script_out)) 385*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(script_out, out)) 386*da0073e9SAndroid Build Coastguard Worker 387*da0073e9SAndroid Build Coastguard Worker iofile = io.BytesIO() 388*da0073e9SAndroid Build Coastguard Worker torch.jit.save(sm, iofile) 389*da0073e9SAndroid Build Coastguard Worker iofile.seek(0) 390*da0073e9SAndroid Build Coastguard Worker sm = torch.jit.load(iofile) 391*da0073e9SAndroid Build Coastguard Worker script_out_load = sm(inp) 392*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(expected, script_out_load)) 393