xref: /aosp_15_r20/external/pytorch/test/jit/test_await.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport io
4*da0073e9SAndroid Build Coastguard Workerfrom typing import List, Optional, Tuple
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerimport torch
7*da0073e9SAndroid Build Coastguard Workerfrom torch import Tensor
8*da0073e9SAndroid Build Coastguard Workerfrom torch._awaits import _Await as Await
9*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase, make_global
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker
class TestAwait(JitTestCase):
    """Tests for ``torch.jit._awaitable``, ``_awaitable_wait`` and ``_awaitable_nowait``.

    Most tests define a function that creates an awaitable, run it both
    eagerly and under ``torch.jit.script``, and check that both results match
    an expected tensor.
    """

    def _script_and_compare(self, fn, inp: Tensor, expected: Tensor):
        """Script ``fn``, run it eagerly and scripted on ``inp``, and compare.

        Asserts that the scripted output is close to ``expected`` and that it
        is close to the eager output.

        Returns:
            The scripted function, so callers can inspect its graph or
            serialize it.
        """
        sm = torch.jit.script(fn)
        out = fn(inp)
        script_out = sm(inp)
        self.assertTrue(torch.allclose(expected, script_out))
        self.assertTrue(torch.allclose(script_out, out))
        return sm

    def test_await_python(self):
        """An eager awaitable exposes its fn/args; nowait wraps a ready value."""

        def foo(x: int) -> int:
            return x + 13

        aw: Await[int] = torch.jit._awaitable(foo, 13)
        self.assertEqual(aw.fn()(*aw.args()), torch.jit._awaitable_wait(aw))
        nw = torch.jit._awaitable_nowait(33)
        self.assertTrue(nw.is_nowait())
        self.assertEqual(nw.args(), (33,))

    def test_await_type_python(self):
        """An eager awaitable can be stored in a ``List[Await[Tensor]]``."""

        def foo() -> Tensor:
            # A size is required; a bare ``torch.randn()`` is not a valid call.
            return torch.randn(2)

        awaits = torch.jit.annotate(List[Await[Tensor]], [])
        awaits.append(torch.jit._awaitable(foo))
        self.assertEqual(len(awaits), 1)
        self.assertIsInstance(awaits[0], torch._C._Await)

    def test_script(self):
        """Wait on an awaitable of a scripted int-returning function."""

        def delayed(z: int) -> int:
            return z + 3

        def fn(x: Tensor):
            aw: Await[int] = torch.jit._awaitable(delayed, 99)
            a = torch.eye(2)
            b = torch.jit._awaitable_wait(aw)
            return a + b + x

        # delayed(99) == 102 is broadcast-added to eye(2) + zeros(2).
        self._script_and_compare(fn, torch.zeros(2), torch.eye(2) + 102)

    def test_nowait(self):
        """Wait on a nowait awaitable wrapping an int."""

        def fn(x: Tensor):
            aw = torch.jit._awaitable_nowait(13)
            a = torch.eye(2)
            b = torch.jit._awaitable_wait(aw)
            return a + b + x

        self._script_and_compare(fn, torch.zeros(2), torch.eye(2) + 13)

    def test_nowait_class(self):
        """Wait on a nowait awaitable wrapping a TorchScript class instance."""

        class C:
            def __init__(self, a: Tensor, b: Tensor):
                self._a = a
                self._b = b

            def a(self) -> Tensor:
                return self._a

        def fn(x: Tensor):
            aw = torch.jit._awaitable_nowait(C(torch.zeros(2), torch.ones(2)))
            _a = torch.eye(2)
            c = torch.jit._awaitable_wait(aw)
            return _a + c.a() + x

        make_global(C)
        self._script_and_compare(fn, torch.zeros(2), torch.eye(2))

    def test_await_class_arg(self):
        """Pass a TorchScript class instance as the awaitable's argument."""

        class C:
            def __init__(self, a: Tensor, b: Tensor):
                self.__a = a
                self.__b = b

            def a(self) -> Tensor:
                return self.__a

        make_global(C)

        def delayed(c: C) -> Tensor:
            return c.a()

        def fn(x: Tensor):
            c = C(torch.zeros(2), torch.ones(2))
            aw = torch.jit._awaitable(delayed, c)
            _a = torch.eye(2)
            c2_t = torch.jit._awaitable_wait(aw)
            return _a + c2_t + x

        self._script_and_compare(fn, torch.zeros(2), torch.eye(2))

    def test_awaitable_to_await(self):
        """Use a free function taking a class instance as the wait implementation."""

        class C:
            __slots__ = ["_a", "_b"]

            def __init__(self, a: Tensor, b: Tensor):
                self._a = a
                self._b = b

        make_global(C)

        # Cannot be a method of C: TorchScript does not support recursive
        # annotations (``self`` could not be annotated as ``C`` because ``C``
        # is not yet defined at that point).
        def C_wait_impl(self: C):
            return self._a + self._b

        def fn(x: Tensor):
            aw = torch.jit._awaitable(C_wait_impl, C(torch.zeros(2), torch.ones(2)))
            _a = torch.eye(2)
            c_wait_impl_res = torch.jit._awaitable_wait(aw)
            return _a + c_wait_impl_res + x

        self._script_and_compare(
            fn, torch.ones(2), torch.eye(2) + 2 * torch.ones(2)
        )

    def test_await_class_return(self):
        """An awaitable whose result is a class instance, waited exactly once."""

        class C:
            __slots__ = ["a", "b"]

            def __init__(self, a: Tensor, b: Tensor):
                self.a = a
                self.b = b

        make_global(C)

        # Cannot be a method of C: TorchScript does not support recursive
        # annotations (``self`` could not be annotated as ``C`` because ``C``
        # is not yet defined at that point).
        def C_wait_impl(self: C) -> C:
            return C(self.a * 2, self.b * 3)

        def fn_arg_C(x: C) -> Tensor:
            return x.a + x.b

        def fn(x: Tensor):
            aw: Await[C] = torch.jit._awaitable(C_wait_impl, C(x, x))
            _a = torch.eye(2)
            y = fn_arg_C(torch.jit._awaitable_wait(aw))
            return _a + y + x

        # With x == ones: y = 2x + 3x = 5, plus x = 6.
        sm = self._script_and_compare(
            fn, torch.ones(2), torch.eye(2) + 6 * torch.ones(2)
        )
        self.assertGraphContainsExactly(
            sm.graph, kind="prim::awaitable_wait", num_kind_nodes=1
        )

    def test_await_getattr_implicit_convertion(self):
        """Attribute/method access on an ``Await[C]`` implicitly waits on it."""

        class C:
            def __init__(self, a: Tensor, b: Tensor):
                self._a = a
                self._b = b

            def b(self):
                return self._b

        make_global(C)

        # Cannot be a method of C: TorchScript does not support recursive
        # annotations (``self`` could not be annotated as ``C`` because ``C``
        # is not yet defined at that point).
        def C_wait_impl(self: C) -> C:
            return C(self._a * 2, self._b * 3)

        def fn_arg_C(x: C) -> Tensor:
            return x._a + x._b

        def fn(x: Tensor):
            aw: Await[C] = torch.jit._awaitable(C_wait_impl, C(x, x))
            _a = torch.eye(2)
            ai = aw._a
            # Intentionally unused: this access is expected to produce the
            # second prim::awaitable_wait node checked below.
            awb = aw.b()
            c = C(2 * x, 2 * x)
            return _a + ai + x + c._a + c.b()

        # With x == ones: ai = 2, x = 1, c._a = 2, c.b() = 2 -> 7.
        sm = self._script_and_compare(
            fn, torch.ones(2), torch.eye(2) + 7 * torch.ones(2)
        )
        self.assertGraphContainsExactly(
            sm.graph, kind="prim::awaitable_wait", num_kind_nodes=2
        )

    def test_await_nested(self):
        """An awaitable whose result is itself an awaitable (``Await[Await[T]]``)."""

        class C:
            def __init__(self, a: Tensor, b: Tensor):
                self.__a = a
                self.__b = b

            def a(self) -> Tensor:
                return self.__a

        make_global(C)

        def delayed(c: C) -> Await[Tensor]:
            return torch.jit._awaitable_nowait(3 * c.a())

        def fn(x: Tensor) -> Await[Await[Tensor]]:
            return torch.jit._awaitable(delayed, C(2 * x, x))

        def main(x: Tensor) -> Tensor:
            awaw = fn(x)
            return torch.jit._awaitable_wait(torch.jit._awaitable_wait(awaw))

        self._script_and_compare(main, torch.eye(2), 6 * torch.eye(2))

    def test_eager_await_non_scriptable(self):
        """Eager awaitables work with types that TorchScript cannot compile."""

        # Tree type can not be compiled (Recursive type)
        class Tree:
            def __init__(self, v):
                self.parent = torch.jit.annotate(Optional[Tree], None)
                self.v = v

        make_global(Tree)

        def delayed(t: Tree):
            t.v = t.v + 1
            return t

        aw = torch.jit._awaitable(delayed, Tree(2))
        t = torch.jit._awaitable_wait(aw)
        self.assertEqual(t.v, 3)

    def test_await_isinstance(self):
        """``isinstance(aw, torch.jit._Await)`` holds inside scripted code."""

        def delayed(x: Tensor) -> Tensor:
            return 2 * (x + 1)

        def main(x: Tensor) -> Tensor:
            aw = torch.jit._awaitable(delayed, x)
            if torch.jit.is_scripting():
                assert isinstance(aw, torch.jit._Await)
            return torch.jit._awaitable_wait(aw)

        self._script_and_compare(
            main, torch.eye(2), 2 * torch.eye(2) + 2 * torch.ones(2)
        )

    def test_await_eager_lazy(self):
        """An eager awaitable forwards attribute access (``dtype``) to its result."""

        def delayed(x: Tensor) -> Tensor:
            return 2 * (x + 1)

        t = torch.ones(2, dtype=torch.int64)
        aw = torch.jit._awaitable(delayed, t)
        self.assertIsInstance(aw, torch._C._Await)
        self.assertEqual(t.dtype, aw.dtype)

    def test_await_out_of_interpreter(self):
        """An awaitable returned from scripted code can be waited on from Python."""

        def delayed(x: Tensor) -> Tensor:
            return 2 * (x + 1)

        def main(x: Tensor) -> Await[Tensor]:
            aw = torch.jit._awaitable(delayed, x)
            return aw

        inp = torch.eye(2)

        sm = torch.jit.script(main)
        out_aw = main(inp)
        out = torch.jit._awaitable_wait(out_aw)

        script_out_aw = sm(inp)
        script_out = torch.jit._awaitable_wait(script_out_aw)
        self.assertTrue(
            torch.allclose(2 * torch.eye(2) + 2 * torch.ones(2), script_out)
        )
        self.assertTrue(torch.allclose(script_out, out))

    def test_jit_trace(self):
        """Tracing a function that creates and waits on an awaitable."""

        def gap(x: Tensor):
            return torch.relu(x) + torch.sin(x)

        def delayed(x: Tensor) -> Tensor:
            return 2 * (torch.cos(x) + 1)

        def main(x: Tensor, y: Tensor) -> Tensor:
            aw = torch.jit._awaitable(delayed, x)
            # Unused work placed between creating and waiting on the awaitable.
            z = gap(y)
            k = torch.jit._awaitable_wait(aw)
            return y + k

        inp = torch.randn(2)
        tm = torch.jit.trace(main, (inp, inp))
        inp_check = torch.ones(2)
        self.assertEqual(main(inp_check, inp_check), tm(inp_check, inp_check))

    def test_await_multiout_save(self):
        """A tuple-returning awaitable survives ``torch.jit.save``/``load``."""

        def gap(x: Tensor):
            return torch.relu(x) + torch.sin(x)

        def delayed(x: Tensor) -> Tuple[Tensor, List[Tensor]]:
            l = [x * i for i in range(5)]
            return (100 * x, l)

        def main(x: Tensor) -> Tensor:
            aw = torch.jit._awaitable(delayed, x)
            z = gap(x)
            (_, l) = torch.jit._awaitable_wait(aw)
            return l[3] + z

        inp = torch.eye(2)
        # With x == eye(2): l[3] = 3x and gap(x) = relu(x) + sin(x), so the
        # diagonal is 3 + 1 + sin(1) and the off-diagonal is 0. Computing it
        # exactly avoids relying on a rounded literal fitting allclose's rtol.
        expected = (4.0 + torch.sin(torch.tensor(1.0))) * torch.eye(2)
        sm = self._script_and_compare(main, inp, expected)

        iofile = io.BytesIO()
        torch.jit.save(sm, iofile)
        iofile.seek(0)
        loaded = torch.jit.load(iofile)
        script_out_load = loaded(inp)
        self.assertTrue(torch.allclose(expected, script_out_load))

    def test_await_func_arg(self):
        """An ``Await[Tensor]`` passed as a function argument survives save/load."""

        def gap(x: Tensor):
            return torch.relu(x) + torch.sin(x)

        def delayed(x: Tensor) -> Tensor:
            return -1 * x

        def fn(aw: Await[Tensor]) -> Tensor:
            return 3 * torch.jit._awaitable_wait(aw)

        def main(x: Tensor) -> Tensor:
            aw = torch.jit._awaitable(delayed, x)
            # Unused work placed between creating and waiting on the awaitable.
            z = gap(x)
            y = fn(aw)
            return y + x

        inp = torch.eye(2)
        # 3 * (-x) + x == -2x.
        expected = -2 * torch.eye(2)
        sm = self._script_and_compare(main, inp, expected)

        iofile = io.BytesIO()
        torch.jit.save(sm, iofile)
        iofile.seek(0)
        loaded = torch.jit.load(iofile)
        script_out_load = loaded(inp)
        self.assertTrue(torch.allclose(expected, script_out_load))
392*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(expected, script_out_load))
393