# xref: /aosp_15_r20/external/pytorch/test/jit/test_await.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["oncall: jit"]
2
3import io
4from typing import List, Optional, Tuple
5
6import torch
7from torch import Tensor
8from torch._awaits import _Await as Await
9from torch.testing._internal.jit_utils import JitTestCase, make_global
10
11
12class TestAwait(JitTestCase):
13    def test_await_python(self):
14        def foo(x: int) -> int:
15            return x + 13
16
17        aw: Await[int] = torch.jit._awaitable(foo, 13)
18        self.assertTrue(aw.fn()(*aw.args()) == torch.jit._awaitable_wait(aw))
19        nw = torch.jit._awaitable_nowait(33)
20        self.assertTrue(nw.is_nowait())
21        self.assertTrue(nw.args() == (33,))
22
23    def test_await_type_python(self):
24        def foo() -> Tensor:
25            return torch.randn()
26
27        awaits = torch.jit.annotate(List[Await[Tensor]], [])
28        awaits.append(torch.jit._awaitable(foo))
29
30    def test_script(self):
31        def delayed(z: int) -> int:
32            return z + 3
33
34        def fn(x: Tensor):
35            aw: Await[int] = torch.jit._awaitable(delayed, 99)
36            a = torch.eye(2)
37            b = torch.jit._awaitable_wait(aw)
38            return a + b + x
39
40        inp = torch.zeros(2)
41
42        sm = torch.jit.script(fn)
43        out = fn(inp)
44        script_out = sm(inp)
45        self.assertTrue(torch.allclose(torch.eye(2) + 102, script_out))
46        self.assertTrue(torch.allclose(script_out, out))
47
48    def test_nowait(self):
49        def fn(x: Tensor):
50            aw = torch.jit._awaitable_nowait(13)
51            a = torch.eye(2)
52            b = torch.jit._awaitable_wait(aw)
53            return a + b + x
54
55        inp = torch.zeros(2)
56
57        sm = torch.jit.script(fn)
58        out = fn(inp)
59        script_out = sm(inp)
60        self.assertTrue(torch.allclose(torch.eye(2) + 13, script_out))
61        self.assertTrue(torch.allclose(script_out, out))
62
63    def test_nowait_class(self):
64        class C:
65            def __init__(self, a: Tensor, b: Tensor):
66                self._a = a
67                self._b = b
68
69            def a(self) -> Tensor:
70                return self._a
71
72        def fn(x: Tensor):
73            aw = torch.jit._awaitable_nowait(C(torch.zeros(2), torch.ones(2)))
74            _a = torch.eye(2)
75            c = torch.jit._awaitable_wait(aw)
76            return _a + c.a() + x
77
78        make_global(C)
79        inp = torch.zeros(2)
80
81        sm = torch.jit.script(fn)
82        out = fn(inp)
83        script_out = sm(inp)
84        self.assertTrue(torch.allclose(torch.eye(2), script_out))
85        self.assertTrue(torch.allclose(script_out, out))
86
87    def test_await_class_arg(self):
88        class C:
89            def __init__(self, a: Tensor, b: Tensor):
90                self.__a = a
91                self.__b = b
92
93            def a(self) -> Tensor:
94                return self.__a
95
96        make_global(C)
97
98        def delayed(c: C) -> Tensor:
99            return c.a()
100
101        def fn(x: Tensor):
102            c = C(torch.zeros(2), torch.ones(2))
103            aw = torch.jit._awaitable(delayed, c)
104            _a = torch.eye(2)
105            c2_t = torch.jit._awaitable_wait(aw)
106            return _a + c2_t + x
107
108        inp = torch.zeros(2)
109
110        sm = torch.jit.script(fn)
111        out = fn(inp)
112        script_out = sm(inp)
113        self.assertTrue(torch.allclose(torch.eye(2), script_out))
114        self.assertTrue(torch.allclose(script_out, out))
115
116    def test_awaitable_to_await(self):
117        class C:
118            __slots__ = ["_a", "_b"]
119
120            def __init__(self, a: Tensor, b: Tensor):
121                self._a = a
122                self._b = b
123
124        make_global(C)
125
126        # Can not stay in the class as Jit does not support Recursive annotations
127        # (self in wait_impl can not be annotated as C as C is not defined by this time)
128        def C_wait_impl(self: C):
129            return self._a + self._b
130
131        def fn(x: Tensor):
132            aw = torch.jit._awaitable(C_wait_impl, C(torch.zeros(2), torch.ones(2)))
133            _a = torch.eye(2)
134            c_wait_impl_res = torch.jit._awaitable_wait(aw)
135            return _a + c_wait_impl_res + x
136
137        inp = torch.ones(2)
138
139        sm = torch.jit.script(fn)
140        out = fn(inp)
141        script_out = sm(inp)
142        self.assertTrue(torch.allclose(torch.eye(2) + 2 * torch.ones(2), script_out))
143        self.assertTrue(torch.allclose(script_out, out))
144
145    def test_await_class_return(self):
146        class C:
147            __slots__ = ["a", "b"]
148
149            def __init__(self, a: Tensor, b: Tensor):
150                self.a = a
151                self.b = b
152
153        make_global(C)
154
155        # Can not stay in the class as Jit does not support Recursive annotations
156        # (self in wait_impl can not be annotated as C as C is not defined by this time)
157        def C_wait_impl(self: C) -> C:
158            return C(self.a * 2, self.b * 3)
159
160        def fn_arg_C(x: C) -> Tensor:
161            return x.a + x.b
162
163        def fn(x: Tensor):
164            aw: Await[C] = torch.jit._awaitable(C_wait_impl, C(x, x))
165            _a = torch.eye(2)
166            y = fn_arg_C(torch.jit._awaitable_wait(aw))
167            return _a + y + x
168
169        inp = torch.ones(2)
170
171        sm = torch.jit.script(fn)
172        out = fn(inp)
173        script_out = sm(inp)
174        self.assertTrue(torch.allclose(torch.eye(2) + 6 * torch.ones(2), script_out))
175        self.assertTrue(torch.allclose(script_out, out))
176        self.assertGraphContainsExactly(
177            sm.graph, kind="prim::awaitable_wait", num_kind_nodes=1
178        )
179
180    def test_await_getattr_implicit_convertion(self):
181        class C:
182            def __init__(self, a: Tensor, b: Tensor):
183                self._a = a
184                self._b = b
185
186            def b(self):
187                return self._b
188
189        make_global(C)
190
191        # Can not stay in the class as Jit does not support Recursive annotations
192        # (self in wait_impl can not be annotated as C as C is not defined by this time)
193        def C_wait_impl(self: C) -> C:
194            return C(self._a * 2, self._b * 3)
195
196        def fn_arg_C(x: C) -> Tensor:
197            return x._a + x._b
198
199        def fn(x: Tensor):
200            aw: Await[C] = torch.jit._awaitable(C_wait_impl, C(x, x))
201            _a = torch.eye(2)
202            ai = aw._a
203            awb = aw.b()
204            c = C(2 * x, 2 * x)
205            return _a + ai + x + c._a + c.b()
206
207        inp = torch.ones(2)
208
209        sm = torch.jit.script(fn)
210        out = fn(inp)
211        script_out = sm(inp)
212        self.assertTrue(torch.allclose(torch.eye(2) + 7 * torch.ones(2), script_out))
213        self.assertTrue(torch.allclose(script_out, out))
214        self.assertGraphContainsExactly(
215            sm.graph, kind="prim::awaitable_wait", num_kind_nodes=2
216        )
217
218    def test_await_nested(self):
219        class C:
220            def __init__(self, a: Tensor, b: Tensor):
221                self.__a = a
222                self.__b = b
223
224            def a(self) -> Tensor:
225                return self.__a
226
227        make_global(C)
228
229        def delayed(c: C) -> Await[Tensor]:
230            return torch.jit._awaitable_nowait(3 * c.a())
231
232        def fn(x: Tensor) -> Await[Await[Tensor]]:
233            return torch.jit._awaitable(delayed, C(2 * x, x))
234
235        def main(x: Tensor) -> Tensor:
236            awaw = fn(x)
237            return torch.jit._awaitable_wait(torch.jit._awaitable_wait(awaw))
238
239        inp = torch.eye(2)
240
241        sm = torch.jit.script(main)
242        out = main(inp)
243        script_out = sm(inp)
244        self.assertTrue(torch.allclose(6 * torch.eye(2), script_out))
245        self.assertTrue(torch.allclose(script_out, out))
246
247    def test_eager_await_non_scriptable(self):
248        # Tree type can not be compiled (Recursive type)
249        class Tree:
250            def __init__(self, v):
251                self.parent = torch.jit.annotate(Optional[Tree], None)
252                self.v = v
253
254        make_global(Tree)
255
256        def delayed(t: Tree):
257            t.v = t.v + 1
258            return t
259
260        aw = torch.jit._awaitable(delayed, Tree(2))
261        t = torch.jit._awaitable_wait(aw)
262        self.assertTrue(t.v == 3)
263
264    def test_await_isinstance(self):
265        def delayed(x: Tensor) -> Tensor:
266            return 2 * (x + 1)
267
268        def main(x: Tensor) -> Tensor:
269            aw = torch.jit._awaitable(delayed, x)
270            if torch.jit.is_scripting():
271                assert isinstance(aw, torch.jit._Await)
272            return torch.jit._awaitable_wait(aw)
273
274        inp = torch.eye(2)
275
276        sm = torch.jit.script(main)
277        out = main(inp)
278        script_out = sm(inp)
279        self.assertTrue(
280            torch.allclose(2 * torch.eye(2) + 2 * torch.ones(2), script_out)
281        )
282        self.assertTrue(torch.allclose(script_out, out))
283
284    def test_await_eager_lazy(self):
285        def delayed(x: Tensor) -> Tensor:
286            return 2 * (x + 1)
287
288        t = torch.ones(2, dtype=torch.int64)
289        aw = torch.jit._awaitable(delayed, t)
290        self.assertTrue(isinstance(aw, torch._C._Await))
291        self.assertTrue(t.dtype == aw.dtype)
292
293    def test_await_out_of_interpreter(self):
294        def delayed(x: Tensor) -> Tensor:
295            return 2 * (x + 1)
296
297        def main(x: Tensor) -> Await[Tensor]:
298            aw = torch.jit._awaitable(delayed, x)
299            return aw
300
301        inp = torch.eye(2)
302
303        sm = torch.jit.script(main)
304        out_aw = main(inp)
305        out = torch.jit._awaitable_wait(out_aw)
306
307        script_out_aw = sm(inp)
308        script_out = torch.jit._awaitable_wait(script_out_aw)
309        self.assertTrue(
310            torch.allclose(2 * torch.eye(2) + 2 * torch.ones(2), script_out)
311        )
312        self.assertTrue(torch.allclose(script_out, out))
313
314    def test_jit_trace(self):
315        def gap(x: Tensor):
316            return torch.relu(x) + torch.sin(x)
317
318        def delayed(x: Tensor) -> Tensor:
319            return 2 * (torch.cos(x) + 1)
320
321        def main(x: Tensor, y: Tensor) -> Tensor:
322            aw = torch.jit._awaitable(delayed, x)
323            z = gap(y)
324            k = torch.jit._awaitable_wait(aw)
325            return y + k
326
327        inp = torch.randn(2)
328        tm = torch.jit.trace(main, (inp, inp))
329        inp_check = torch.ones(2)
330        self.assertEqual(main(inp_check, inp_check), tm(inp_check, inp_check))
331
332    def test_await_multiout_save(self):
333        def gap(x: Tensor):
334            return torch.relu(x) + torch.sin(x)
335
336        def delayed(x: Tensor) -> Tuple[Tensor, List[Tensor]]:
337            l = [x * i for i in range(5)]
338            return (100 * x, l)
339
340        def main(x: Tensor) -> Tensor:
341            aw = torch.jit._awaitable(delayed, x)
342            z = gap(x)
343            (_, l) = torch.jit._awaitable_wait(aw)
344            return l[3] + z
345
346        inp = torch.eye(2)
347
348        sm = torch.jit.script(main)
349        out = main(inp)
350        script_out = sm(inp)
351        expected = 4.8415 * torch.eye(2)
352        self.assertTrue(torch.allclose(expected, script_out))
353        self.assertTrue(torch.allclose(script_out, out))
354
355        iofile = io.BytesIO()
356        torch.jit.save(sm, iofile)
357        iofile.seek(0)
358        sm = torch.jit.load(iofile)
359        script_out_load = sm(inp)
360        self.assertTrue(torch.allclose(expected, script_out_load))
361
362    def test_await_func_arg(self):
363        def gap(x: Tensor):
364            return torch.relu(x) + torch.sin(x)
365
366        def delayed(x: Tensor) -> Tensor:
367            return -1 * x
368
369        def fn(aw: Await[Tensor]) -> Tensor:
370            return 3 * torch.jit._awaitable_wait(aw)
371
372        def main(x: Tensor) -> Tensor:
373            aw = torch.jit._awaitable(delayed, x)
374            z = gap(x)
375            y = fn(aw)
376            return y + x
377
378        inp = torch.eye(2)
379
380        sm = torch.jit.script(main)
381        out = main(inp)
382        script_out = sm(inp)
383        expected = -2 * torch.eye(2)
384        self.assertTrue(torch.allclose(expected, script_out))
385        self.assertTrue(torch.allclose(script_out, out))
386
387        iofile = io.BytesIO()
388        torch.jit.save(sm, iofile)
389        iofile.seek(0)
390        sm = torch.jit.load(iofile)
391        script_out_load = sm(inp)
392        self.assertTrue(torch.allclose(expected, script_out_load))
393