1# Owner(s): ["oncall: jit"] 2 3import io 4from typing import List, Optional, Tuple 5 6import torch 7from torch import Tensor 8from torch._awaits import _Await as Await 9from torch.testing._internal.jit_utils import JitTestCase, make_global 10 11 12class TestAwait(JitTestCase): 13 def test_await_python(self): 14 def foo(x: int) -> int: 15 return x + 13 16 17 aw: Await[int] = torch.jit._awaitable(foo, 13) 18 self.assertTrue(aw.fn()(*aw.args()) == torch.jit._awaitable_wait(aw)) 19 nw = torch.jit._awaitable_nowait(33) 20 self.assertTrue(nw.is_nowait()) 21 self.assertTrue(nw.args() == (33,)) 22 23 def test_await_type_python(self): 24 def foo() -> Tensor: 25 return torch.randn() 26 27 awaits = torch.jit.annotate(List[Await[Tensor]], []) 28 awaits.append(torch.jit._awaitable(foo)) 29 30 def test_script(self): 31 def delayed(z: int) -> int: 32 return z + 3 33 34 def fn(x: Tensor): 35 aw: Await[int] = torch.jit._awaitable(delayed, 99) 36 a = torch.eye(2) 37 b = torch.jit._awaitable_wait(aw) 38 return a + b + x 39 40 inp = torch.zeros(2) 41 42 sm = torch.jit.script(fn) 43 out = fn(inp) 44 script_out = sm(inp) 45 self.assertTrue(torch.allclose(torch.eye(2) + 102, script_out)) 46 self.assertTrue(torch.allclose(script_out, out)) 47 48 def test_nowait(self): 49 def fn(x: Tensor): 50 aw = torch.jit._awaitable_nowait(13) 51 a = torch.eye(2) 52 b = torch.jit._awaitable_wait(aw) 53 return a + b + x 54 55 inp = torch.zeros(2) 56 57 sm = torch.jit.script(fn) 58 out = fn(inp) 59 script_out = sm(inp) 60 self.assertTrue(torch.allclose(torch.eye(2) + 13, script_out)) 61 self.assertTrue(torch.allclose(script_out, out)) 62 63 def test_nowait_class(self): 64 class C: 65 def __init__(self, a: Tensor, b: Tensor): 66 self._a = a 67 self._b = b 68 69 def a(self) -> Tensor: 70 return self._a 71 72 def fn(x: Tensor): 73 aw = torch.jit._awaitable_nowait(C(torch.zeros(2), torch.ones(2))) 74 _a = torch.eye(2) 75 c = torch.jit._awaitable_wait(aw) 76 return _a + c.a() + x 77 78 make_global(C) 79 inp = torch.zeros(2) 80 81 sm = torch.jit.script(fn) 82 out = fn(inp) 83 script_out = sm(inp) 84 self.assertTrue(torch.allclose(torch.eye(2), script_out)) 85 self.assertTrue(torch.allclose(script_out, out)) 86 87 def test_await_class_arg(self): 88 class C: 89 def __init__(self, a: Tensor, b: Tensor): 90 self.__a = a 91 self.__b = b 92 93 def a(self) -> Tensor: 94 return self.__a 95 96 make_global(C) 97 98 def delayed(c: C) -> Tensor: 99 return c.a() 100 101 def fn(x: Tensor): 102 c = C(torch.zeros(2), torch.ones(2)) 103 aw = torch.jit._awaitable(delayed, c) 104 _a = torch.eye(2) 105 c2_t = torch.jit._awaitable_wait(aw) 106 return _a + c2_t + x 107 108 inp = torch.zeros(2) 109 110 sm = torch.jit.script(fn) 111 out = fn(inp) 112 script_out = sm(inp) 113 self.assertTrue(torch.allclose(torch.eye(2), script_out)) 114 self.assertTrue(torch.allclose(script_out, out)) 115 116 def test_awaitable_to_await(self): 117 class C: 118 __slots__ = ["_a", "_b"] 119 120 def __init__(self, a: Tensor, b: Tensor): 121 self._a = a 122 self._b = b 123 124 make_global(C) 125 126 # Can not stay in the class as Jit does not support Recursive annotations 127 # (self in wait_impl can not be annotated as C as C is not defined by this time) 128 def C_wait_impl(self: C): 129 return self._a + self._b 130 131 def fn(x: Tensor): 132 aw = torch.jit._awaitable(C_wait_impl, C(torch.zeros(2), torch.ones(2))) 133 _a = torch.eye(2) 134 c_wait_impl_res = torch.jit._awaitable_wait(aw) 135 return _a + c_wait_impl_res + x 136 137 inp = torch.ones(2) 138 139 sm = torch.jit.script(fn) 140 out = fn(inp) 141 script_out = sm(inp) 142 self.assertTrue(torch.allclose(torch.eye(2) + 2 * torch.ones(2), script_out)) 143 self.assertTrue(torch.allclose(script_out, out)) 144 145 def test_await_class_return(self): 146 class C: 147 __slots__ = ["a", "b"] 148 149 def __init__(self, a: Tensor, b: Tensor): 150 self.a = a 151 self.b = b 152 153 make_global(C) 154 155 # Can not stay in the class as Jit does not support Recursive annotations 156 # (self in wait_impl can not be annotated as C as C is not defined by this time) 157 def C_wait_impl(self: C) -> C: 158 return C(self.a * 2, self.b * 3) 159 160 def fn_arg_C(x: C) -> Tensor: 161 return x.a + x.b 162 163 def fn(x: Tensor): 164 aw: Await[C] = torch.jit._awaitable(C_wait_impl, C(x, x)) 165 _a = torch.eye(2) 166 y = fn_arg_C(torch.jit._awaitable_wait(aw)) 167 return _a + y + x 168 169 inp = torch.ones(2) 170 171 sm = torch.jit.script(fn) 172 out = fn(inp) 173 script_out = sm(inp) 174 self.assertTrue(torch.allclose(torch.eye(2) + 6 * torch.ones(2), script_out)) 175 self.assertTrue(torch.allclose(script_out, out)) 176 self.assertGraphContainsExactly( 177 sm.graph, kind="prim::awaitable_wait", num_kind_nodes=1 178 ) 179 180 def test_await_getattr_implicit_convertion(self): 181 class C: 182 def __init__(self, a: Tensor, b: Tensor): 183 self._a = a 184 self._b = b 185 186 def b(self): 187 return self._b 188 189 make_global(C) 190 191 # Can not stay in the class as Jit does not support Recursive annotations 192 # (self in wait_impl can not be annotated as C as C is not defined by this time) 193 def C_wait_impl(self: C) -> C: 194 return C(self._a * 2, self._b * 3) 195 196 def fn_arg_C(x: C) -> Tensor: 197 return x._a + x._b 198 199 def fn(x: Tensor): 200 aw: Await[C] = torch.jit._awaitable(C_wait_impl, C(x, x)) 201 _a = torch.eye(2) 202 ai = aw._a 203 awb = aw.b() 204 c = C(2 * x, 2 * x) 205 return _a + ai + x + c._a + c.b() 206 207 inp = torch.ones(2) 208 209 sm = torch.jit.script(fn) 210 out = fn(inp) 211 script_out = sm(inp) 212 self.assertTrue(torch.allclose(torch.eye(2) + 7 * torch.ones(2), script_out)) 213 self.assertTrue(torch.allclose(script_out, out)) 214 self.assertGraphContainsExactly( 215 sm.graph, kind="prim::awaitable_wait", num_kind_nodes=2 216 ) 217 218 def test_await_nested(self): 219 class C: 220 def __init__(self, a: Tensor, b: Tensor): 221 self.__a = a 222 self.__b = b 223 224 def a(self) -> Tensor: 225 return self.__a 226 227 make_global(C) 228 229 def delayed(c: C) -> Await[Tensor]: 230 return torch.jit._awaitable_nowait(3 * c.a()) 231 232 def fn(x: Tensor) -> Await[Await[Tensor]]: 233 return torch.jit._awaitable(delayed, C(2 * x, x)) 234 235 def main(x: Tensor) -> Tensor: 236 awaw = fn(x) 237 return torch.jit._awaitable_wait(torch.jit._awaitable_wait(awaw)) 238 239 inp = torch.eye(2) 240 241 sm = torch.jit.script(main) 242 out = main(inp) 243 script_out = sm(inp) 244 self.assertTrue(torch.allclose(6 * torch.eye(2), script_out)) 245 self.assertTrue(torch.allclose(script_out, out)) 246 247 def test_eager_await_non_scriptable(self): 248 # Tree type can not be compiled (Recursive type) 249 class Tree: 250 def __init__(self, v): 251 self.parent = torch.jit.annotate(Optional[Tree], None) 252 self.v = v 253 254 make_global(Tree) 255 256 def delayed(t: Tree): 257 t.v = t.v + 1 258 return t 259 260 aw = torch.jit._awaitable(delayed, Tree(2)) 261 t = torch.jit._awaitable_wait(aw) 262 self.assertTrue(t.v == 3) 263 264 def test_await_isinstance(self): 265 def delayed(x: Tensor) -> Tensor: 266 return 2 * (x + 1) 267 268 def main(x: Tensor) -> Tensor: 269 aw = torch.jit._awaitable(delayed, x) 270 if torch.jit.is_scripting(): 271 assert isinstance(aw, torch.jit._Await) 272 return torch.jit._awaitable_wait(aw) 273 274 inp = torch.eye(2) 275 276 sm = torch.jit.script(main) 277 out = main(inp) 278 script_out = sm(inp) 279 self.assertTrue( 280 torch.allclose(2 * torch.eye(2) + 2 * torch.ones(2), script_out) 281 ) 282 self.assertTrue(torch.allclose(script_out, out)) 283 284 def test_await_eager_lazy(self): 285 def delayed(x: Tensor) -> Tensor: 286 return 2 * (x + 1) 287 288 t = torch.ones(2, dtype=torch.int64) 289 aw = torch.jit._awaitable(delayed, t) 290 self.assertTrue(isinstance(aw, torch._C._Await)) 291 self.assertTrue(t.dtype == aw.dtype) 292 293 def test_await_out_of_interpreter(self): 294 def delayed(x: Tensor) -> Tensor: 295 return 2 * (x + 1) 296 297 def main(x: Tensor) -> Await[Tensor]: 298 aw = torch.jit._awaitable(delayed, x) 299 return aw 300 301 inp = torch.eye(2) 302 303 sm = torch.jit.script(main) 304 out_aw = main(inp) 305 out = torch.jit._awaitable_wait(out_aw) 306 307 script_out_aw = sm(inp) 308 script_out = torch.jit._awaitable_wait(script_out_aw) 309 self.assertTrue( 310 torch.allclose(2 * torch.eye(2) + 2 * torch.ones(2), script_out) 311 ) 312 self.assertTrue(torch.allclose(script_out, out)) 313 314 def test_jit_trace(self): 315 def gap(x: Tensor): 316 return torch.relu(x) + torch.sin(x) 317 318 def delayed(x: Tensor) -> Tensor: 319 return 2 * (torch.cos(x) + 1) 320 321 def main(x: Tensor, y: Tensor) -> Tensor: 322 aw = torch.jit._awaitable(delayed, x) 323 z = gap(y) 324 k = torch.jit._awaitable_wait(aw) 325 return y + k 326 327 inp = torch.randn(2) 328 tm = torch.jit.trace(main, (inp, inp)) 329 inp_check = torch.ones(2) 330 self.assertEqual(main(inp_check, inp_check), tm(inp_check, inp_check)) 331 332 def test_await_multiout_save(self): 333 def gap(x: Tensor): 334 return torch.relu(x) + torch.sin(x) 335 336 def delayed(x: Tensor) -> Tuple[Tensor, List[Tensor]]: 337 l = [x * i for i in range(5)] 338 return (100 * x, l) 339 340 def main(x: Tensor) -> Tensor: 341 aw = torch.jit._awaitable(delayed, x) 342 z = gap(x) 343 (_, l) = torch.jit._awaitable_wait(aw) 344 return l[3] + z 345 346 inp = torch.eye(2) 347 348 sm = torch.jit.script(main) 349 out = main(inp) 350 script_out = sm(inp) 351 expected = 4.8415 * torch.eye(2) 352 self.assertTrue(torch.allclose(expected, script_out)) 353 self.assertTrue(torch.allclose(script_out, out)) 354 355 iofile = io.BytesIO() 356 torch.jit.save(sm, iofile) 357 iofile.seek(0) 358 sm = torch.jit.load(iofile) 359 script_out_load = sm(inp) 360 self.assertTrue(torch.allclose(expected, script_out_load)) 361 362 def test_await_func_arg(self): 363 def gap(x: Tensor): 364 return torch.relu(x) + torch.sin(x) 365 366 def delayed(x: Tensor) -> Tensor: 367 return -1 * x 368 369 def fn(aw: Await[Tensor]) -> Tensor: 370 return 3 * torch.jit._awaitable_wait(aw) 371 372 def main(x: Tensor) -> Tensor: 373 aw = torch.jit._awaitable(delayed, x) 374 z = gap(x) 375 y = fn(aw) 376 return y + x 377 378 inp = torch.eye(2) 379 380 sm = torch.jit.script(main) 381 out = main(inp) 382 script_out = sm(inp) 383 expected = -2 * torch.eye(2) 384 self.assertTrue(torch.allclose(expected, script_out)) 385 self.assertTrue(torch.allclose(script_out, out)) 386 387 iofile = io.BytesIO() 388 torch.jit.save(sm, iofile) 389 iofile.seek(0) 390 sm = torch.jit.load(iofile) 391 script_out_load = sm(inp) 392 self.assertTrue(torch.allclose(expected, script_out_load)) 393