1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs 2*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: unknown"] 3*da0073e9SAndroid Build Coastguard Worker 4*da0073e9SAndroid Build Coastguard Workerimport threading 5*da0073e9SAndroid Build Coastguard Workerimport time 6*da0073e9SAndroid Build Coastguard Workerimport torch 7*da0073e9SAndroid Build Coastguard Workerimport unittest 8*da0073e9SAndroid Build Coastguard Workerfrom torch.futures import Future 9*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import IS_WINDOWS, TestCase, TemporaryFileName, run_tests 10*da0073e9SAndroid Build Coastguard Workerfrom typing import TypeVar 11*da0073e9SAndroid Build Coastguard Worker 12*da0073e9SAndroid Build Coastguard WorkerT = TypeVar("T") 13*da0073e9SAndroid Build Coastguard Worker 14*da0073e9SAndroid Build Coastguard Worker 15*da0073e9SAndroid Build Coastguard Workerdef add_one(fut): 16*da0073e9SAndroid Build Coastguard Worker return fut.wait() + 1 17*da0073e9SAndroid Build Coastguard Worker 18*da0073e9SAndroid Build Coastguard Worker 19*da0073e9SAndroid Build Coastguard Workerclass TestFuture(TestCase): 20*da0073e9SAndroid Build Coastguard Worker def test_set_exception(self) -> None: 21*da0073e9SAndroid Build Coastguard Worker # This test is to ensure errors can propagate across futures. 22*da0073e9SAndroid Build Coastguard Worker error_msg = "Intentional Value Error" 23*da0073e9SAndroid Build Coastguard Worker value_error = ValueError(error_msg) 24*da0073e9SAndroid Build Coastguard Worker 25*da0073e9SAndroid Build Coastguard Worker f = Future[T]() # type: ignore[valid-type] 26*da0073e9SAndroid Build Coastguard Worker # Set exception 27*da0073e9SAndroid Build Coastguard Worker f.set_exception(value_error) 28*da0073e9SAndroid Build Coastguard Worker # Exception should throw on wait 29*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "Intentional"): 30*da0073e9SAndroid Build Coastguard Worker f.wait() 31*da0073e9SAndroid Build Coastguard Worker 32*da0073e9SAndroid Build Coastguard Worker # Exception should also throw on value 33*da0073e9SAndroid Build Coastguard Worker f = Future[T]() # type: ignore[valid-type] 34*da0073e9SAndroid Build Coastguard Worker f.set_exception(value_error) 35*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "Intentional"): 36*da0073e9SAndroid Build Coastguard Worker f.value() 37*da0073e9SAndroid Build Coastguard Worker 38*da0073e9SAndroid Build Coastguard Worker def cb(fut): 39*da0073e9SAndroid Build Coastguard Worker fut.value() 40*da0073e9SAndroid Build Coastguard Worker 41*da0073e9SAndroid Build Coastguard Worker f = Future[T]() # type: ignore[valid-type] 42*da0073e9SAndroid Build Coastguard Worker f.set_exception(value_error) 43*da0073e9SAndroid Build Coastguard Worker 44*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Got the following error"): 45*da0073e9SAndroid Build Coastguard Worker cb_fut = f.then(cb) 46*da0073e9SAndroid Build Coastguard Worker cb_fut.wait() 47*da0073e9SAndroid Build Coastguard Worker 48*da0073e9SAndroid Build Coastguard Worker def test_set_exception_multithreading(self) -> None: 49*da0073e9SAndroid Build Coastguard Worker # Ensure errors can propagate when one thread waits on future result 50*da0073e9SAndroid Build Coastguard Worker # and the other sets it with an error. 51*da0073e9SAndroid Build Coastguard Worker error_msg = "Intentional Value Error" 52*da0073e9SAndroid Build Coastguard Worker value_error = ValueError(error_msg) 53*da0073e9SAndroid Build Coastguard Worker 54*da0073e9SAndroid Build Coastguard Worker def wait_future(f): 55*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "Intentional"): 56*da0073e9SAndroid Build Coastguard Worker f.wait() 57*da0073e9SAndroid Build Coastguard Worker 58*da0073e9SAndroid Build Coastguard Worker f = Future[T]() # type: ignore[valid-type] 59*da0073e9SAndroid Build Coastguard Worker t = threading.Thread(target=wait_future, args=(f, )) 60*da0073e9SAndroid Build Coastguard Worker t.start() 61*da0073e9SAndroid Build Coastguard Worker f.set_exception(value_error) 62*da0073e9SAndroid Build Coastguard Worker t.join() 63*da0073e9SAndroid Build Coastguard Worker 64*da0073e9SAndroid Build Coastguard Worker def cb(fut): 65*da0073e9SAndroid Build Coastguard Worker fut.value() 66*da0073e9SAndroid Build Coastguard Worker 67*da0073e9SAndroid Build Coastguard Worker def then_future(f): 68*da0073e9SAndroid Build Coastguard Worker fut = f.then(cb) 69*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Got the following error"): 70*da0073e9SAndroid Build Coastguard Worker fut.wait() 71*da0073e9SAndroid Build Coastguard Worker 72*da0073e9SAndroid Build Coastguard Worker f = Future[T]() # type: ignore[valid-type] 73*da0073e9SAndroid Build Coastguard Worker t = threading.Thread(target=then_future, args=(f, )) 74*da0073e9SAndroid Build Coastguard Worker t.start() 75*da0073e9SAndroid Build Coastguard Worker f.set_exception(value_error) 76*da0073e9SAndroid Build Coastguard Worker t.join() 77*da0073e9SAndroid Build Coastguard Worker 78*da0073e9SAndroid Build Coastguard Worker def test_done(self) -> None: 79*da0073e9SAndroid Build Coastguard Worker f = Future[torch.Tensor]() 80*da0073e9SAndroid Build Coastguard Worker self.assertFalse(f.done()) 81*da0073e9SAndroid Build Coastguard Worker 82*da0073e9SAndroid Build Coastguard Worker f.set_result(torch.ones(2, 2)) 83*da0073e9SAndroid Build Coastguard Worker self.assertTrue(f.done()) 84*da0073e9SAndroid Build Coastguard Worker 85*da0073e9SAndroid Build Coastguard Worker def test_done_exception(self) -> None: 86*da0073e9SAndroid Build Coastguard Worker err_msg = "Intentional Value Error" 87*da0073e9SAndroid Build Coastguard Worker 88*da0073e9SAndroid Build Coastguard Worker def raise_exception(unused_future): 89*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(err_msg) 90*da0073e9SAndroid Build Coastguard Worker 91*da0073e9SAndroid Build Coastguard Worker f1 = Future[torch.Tensor]() 92*da0073e9SAndroid Build Coastguard Worker self.assertFalse(f1.done()) 93*da0073e9SAndroid Build Coastguard Worker f1.set_result(torch.ones(2, 2)) 94*da0073e9SAndroid Build Coastguard Worker self.assertTrue(f1.done()) 95*da0073e9SAndroid Build Coastguard Worker 96*da0073e9SAndroid Build Coastguard Worker f2 = f1.then(raise_exception) 97*da0073e9SAndroid Build Coastguard Worker self.assertTrue(f2.done()) 98*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, err_msg): 99*da0073e9SAndroid Build Coastguard Worker f2.wait() 100*da0073e9SAndroid Build Coastguard Worker 101*da0073e9SAndroid Build Coastguard Worker def test_wait(self) -> None: 102*da0073e9SAndroid Build Coastguard Worker f = Future[torch.Tensor]() 103*da0073e9SAndroid Build Coastguard Worker f.set_result(torch.ones(2, 2)) 104*da0073e9SAndroid Build Coastguard Worker 105*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f.wait(), torch.ones(2, 2)) 106*da0073e9SAndroid Build Coastguard Worker 107*da0073e9SAndroid Build Coastguard Worker def test_wait_multi_thread(self) -> None: 108*da0073e9SAndroid Build Coastguard Worker 109*da0073e9SAndroid Build Coastguard Worker def slow_set_future(fut, value): 110*da0073e9SAndroid Build Coastguard Worker time.sleep(0.5) 111*da0073e9SAndroid Build Coastguard Worker fut.set_result(value) 112*da0073e9SAndroid Build Coastguard Worker 113*da0073e9SAndroid Build Coastguard Worker f = Future[torch.Tensor]() 114*da0073e9SAndroid Build Coastguard Worker 115*da0073e9SAndroid Build Coastguard Worker t = threading.Thread(target=slow_set_future, args=(f, torch.ones(2, 2))) 116*da0073e9SAndroid Build Coastguard Worker t.start() 117*da0073e9SAndroid Build Coastguard Worker 118*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f.wait(), torch.ones(2, 2)) 119*da0073e9SAndroid Build Coastguard Worker t.join() 120*da0073e9SAndroid Build Coastguard Worker 121*da0073e9SAndroid Build Coastguard Worker def test_mark_future_twice(self) -> None: 122*da0073e9SAndroid Build Coastguard Worker fut = Future[int]() 123*da0073e9SAndroid Build Coastguard Worker fut.set_result(1) 124*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 125*da0073e9SAndroid Build Coastguard Worker RuntimeError, 126*da0073e9SAndroid Build Coastguard Worker "Future can only be marked completed once" 127*da0073e9SAndroid Build Coastguard Worker ): 128*da0073e9SAndroid Build Coastguard Worker fut.set_result(1) 129*da0073e9SAndroid Build Coastguard Worker 130*da0073e9SAndroid Build Coastguard Worker def test_pickle_future(self): 131*da0073e9SAndroid Build Coastguard Worker fut = Future[int]() 132*da0073e9SAndroid Build Coastguard Worker errMsg = "Can not pickle torch.futures.Future" 133*da0073e9SAndroid Build Coastguard Worker with TemporaryFileName() as fname: 134*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, errMsg): 135*da0073e9SAndroid Build Coastguard Worker torch.save(fut, fname) 136*da0073e9SAndroid Build Coastguard Worker 137*da0073e9SAndroid Build Coastguard Worker def test_then(self): 138*da0073e9SAndroid Build Coastguard Worker fut = Future[torch.Tensor]() 139*da0073e9SAndroid Build Coastguard Worker then_fut = fut.then(lambda x: x.wait() + 1) 140*da0073e9SAndroid Build Coastguard Worker 141*da0073e9SAndroid Build Coastguard Worker fut.set_result(torch.ones(2, 2)) 142*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fut.wait(), torch.ones(2, 2)) 143*da0073e9SAndroid Build Coastguard Worker self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1) 144*da0073e9SAndroid Build Coastguard Worker 145*da0073e9SAndroid Build Coastguard Worker def test_chained_then(self): 146*da0073e9SAndroid Build Coastguard Worker fut = Future[torch.Tensor]() 147*da0073e9SAndroid Build Coastguard Worker futs = [] 148*da0073e9SAndroid Build Coastguard Worker last_fut = fut 149*da0073e9SAndroid Build Coastguard Worker for _ in range(20): 150*da0073e9SAndroid Build Coastguard Worker last_fut = last_fut.then(add_one) 151*da0073e9SAndroid Build Coastguard Worker futs.append(last_fut) 152*da0073e9SAndroid Build Coastguard Worker 153*da0073e9SAndroid Build Coastguard Worker fut.set_result(torch.ones(2, 2)) 154*da0073e9SAndroid Build Coastguard Worker 155*da0073e9SAndroid Build Coastguard Worker for i in range(len(futs)): 156*da0073e9SAndroid Build Coastguard Worker self.assertEqual(futs[i].wait(), torch.ones(2, 2) + i + 1) 157*da0073e9SAndroid Build Coastguard Worker 158*da0073e9SAndroid Build Coastguard Worker def _test_then_error(self, cb, errMsg): 159*da0073e9SAndroid Build Coastguard Worker fut = Future[int]() 160*da0073e9SAndroid Build Coastguard Worker then_fut = fut.then(cb) 161*da0073e9SAndroid Build Coastguard Worker 162*da0073e9SAndroid Build Coastguard Worker fut.set_result(5) 163*da0073e9SAndroid Build Coastguard Worker self.assertEqual(5, fut.wait()) 164*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, errMsg): 165*da0073e9SAndroid Build Coastguard Worker then_fut.wait() 166*da0073e9SAndroid Build Coastguard Worker 167*da0073e9SAndroid Build Coastguard Worker def test_then_wrong_arg(self): 168*da0073e9SAndroid Build Coastguard Worker 169*da0073e9SAndroid Build Coastguard Worker def wrong_arg(tensor): 170*da0073e9SAndroid Build Coastguard Worker return tensor + 1 171*da0073e9SAndroid Build Coastguard Worker 172*da0073e9SAndroid Build Coastguard Worker self._test_then_error(wrong_arg, "unsupported operand type.*Future.*int") 173*da0073e9SAndroid Build Coastguard Worker 174*da0073e9SAndroid Build Coastguard Worker def test_then_no_arg(self): 175*da0073e9SAndroid Build Coastguard Worker 176*da0073e9SAndroid Build Coastguard Worker def no_arg(): 177*da0073e9SAndroid Build Coastguard Worker return True 178*da0073e9SAndroid Build Coastguard Worker 179*da0073e9SAndroid Build Coastguard Worker self._test_then_error(no_arg, "takes 0 positional arguments but 1 was given") 180*da0073e9SAndroid Build Coastguard Worker 181*da0073e9SAndroid Build Coastguard Worker def test_then_raise(self): 182*da0073e9SAndroid Build Coastguard Worker 183*da0073e9SAndroid Build Coastguard Worker def raise_value_error(fut): 184*da0073e9SAndroid Build Coastguard Worker raise ValueError("Expected error") 185*da0073e9SAndroid Build Coastguard Worker 186*da0073e9SAndroid Build Coastguard Worker self._test_then_error(raise_value_error, "Expected error") 187*da0073e9SAndroid Build Coastguard Worker 188*da0073e9SAndroid Build Coastguard Worker def test_add_done_callback_simple(self): 189*da0073e9SAndroid Build Coastguard Worker callback_result = False 190*da0073e9SAndroid Build Coastguard Worker 191*da0073e9SAndroid Build Coastguard Worker def callback(fut): 192*da0073e9SAndroid Build Coastguard Worker nonlocal callback_result 193*da0073e9SAndroid Build Coastguard Worker fut.wait() 194*da0073e9SAndroid Build Coastguard Worker callback_result = True 195*da0073e9SAndroid Build Coastguard Worker 196*da0073e9SAndroid Build Coastguard Worker fut = Future[torch.Tensor]() 197*da0073e9SAndroid Build Coastguard Worker fut.add_done_callback(callback) 198*da0073e9SAndroid Build Coastguard Worker 199*da0073e9SAndroid Build Coastguard Worker self.assertFalse(callback_result) 200*da0073e9SAndroid Build Coastguard Worker fut.set_result(torch.ones(2, 2)) 201*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fut.wait(), torch.ones(2, 2)) 202*da0073e9SAndroid Build Coastguard Worker self.assertTrue(callback_result) 203*da0073e9SAndroid Build Coastguard Worker 204*da0073e9SAndroid Build Coastguard Worker def test_add_done_callback_maintains_callback_order(self): 205*da0073e9SAndroid Build Coastguard Worker callback_result = 0 206*da0073e9SAndroid Build Coastguard Worker 207*da0073e9SAndroid Build Coastguard Worker def callback_set1(fut): 208*da0073e9SAndroid Build Coastguard Worker nonlocal callback_result 209*da0073e9SAndroid Build Coastguard Worker fut.wait() 210*da0073e9SAndroid Build Coastguard Worker callback_result = 1 211*da0073e9SAndroid Build Coastguard Worker 212*da0073e9SAndroid Build Coastguard Worker def callback_set2(fut): 213*da0073e9SAndroid Build Coastguard Worker nonlocal callback_result 214*da0073e9SAndroid Build Coastguard Worker fut.wait() 215*da0073e9SAndroid Build Coastguard Worker callback_result = 2 216*da0073e9SAndroid Build Coastguard Worker 217*da0073e9SAndroid Build Coastguard Worker fut = Future[torch.Tensor]() 218*da0073e9SAndroid Build Coastguard Worker fut.add_done_callback(callback_set1) 219*da0073e9SAndroid Build Coastguard Worker fut.add_done_callback(callback_set2) 220*da0073e9SAndroid Build Coastguard Worker 221*da0073e9SAndroid Build Coastguard Worker fut.set_result(torch.ones(2, 2)) 222*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fut.wait(), torch.ones(2, 2)) 223*da0073e9SAndroid Build Coastguard Worker # set2 called last, callback_result = 2 224*da0073e9SAndroid Build Coastguard Worker self.assertEqual(callback_result, 2) 225*da0073e9SAndroid Build Coastguard Worker 226*da0073e9SAndroid Build Coastguard Worker def _test_add_done_callback_error_ignored(self, cb): 227*da0073e9SAndroid Build Coastguard Worker fut = Future[int]() 228*da0073e9SAndroid Build Coastguard Worker fut.add_done_callback(cb) 229*da0073e9SAndroid Build Coastguard Worker 230*da0073e9SAndroid Build Coastguard Worker fut.set_result(5) 231*da0073e9SAndroid Build Coastguard Worker # error msg logged to stdout 232*da0073e9SAndroid Build Coastguard Worker self.assertEqual(5, fut.wait()) 233*da0073e9SAndroid Build Coastguard Worker 234*da0073e9SAndroid Build Coastguard Worker def test_add_done_callback_error_is_ignored(self): 235*da0073e9SAndroid Build Coastguard Worker 236*da0073e9SAndroid Build Coastguard Worker def raise_value_error(fut): 237*da0073e9SAndroid Build Coastguard Worker raise ValueError("Expected error") 238*da0073e9SAndroid Build Coastguard Worker 239*da0073e9SAndroid Build Coastguard Worker self._test_add_done_callback_error_ignored(raise_value_error) 240*da0073e9SAndroid Build Coastguard Worker 241*da0073e9SAndroid Build Coastguard Worker def test_add_done_callback_no_arg_error_is_ignored(self): 242*da0073e9SAndroid Build Coastguard Worker 243*da0073e9SAndroid Build Coastguard Worker def no_arg(): 244*da0073e9SAndroid Build Coastguard Worker return True 245*da0073e9SAndroid Build Coastguard Worker 246*da0073e9SAndroid Build Coastguard Worker # Adding another level of function indirection here on purpose. 247*da0073e9SAndroid Build Coastguard Worker # Otherwise mypy will pick up on no_arg having an incompatible type and fail CI 248*da0073e9SAndroid Build Coastguard Worker self._test_add_done_callback_error_ignored(no_arg) 249*da0073e9SAndroid Build Coastguard Worker 250*da0073e9SAndroid Build Coastguard Worker def test_interleaving_then_and_add_done_callback_maintains_callback_order(self): 251*da0073e9SAndroid Build Coastguard Worker callback_result = 0 252*da0073e9SAndroid Build Coastguard Worker 253*da0073e9SAndroid Build Coastguard Worker def callback_set1(fut): 254*da0073e9SAndroid Build Coastguard Worker nonlocal callback_result 255*da0073e9SAndroid Build Coastguard Worker fut.wait() 256*da0073e9SAndroid Build Coastguard Worker callback_result = 1 257*da0073e9SAndroid Build Coastguard Worker 258*da0073e9SAndroid Build Coastguard Worker def callback_set2(fut): 259*da0073e9SAndroid Build Coastguard Worker nonlocal callback_result 260*da0073e9SAndroid Build Coastguard Worker fut.wait() 261*da0073e9SAndroid Build Coastguard Worker callback_result = 2 262*da0073e9SAndroid Build Coastguard Worker 263*da0073e9SAndroid Build Coastguard Worker def callback_then(fut): 264*da0073e9SAndroid Build Coastguard Worker nonlocal callback_result 265*da0073e9SAndroid Build Coastguard Worker return fut.wait() + callback_result 266*da0073e9SAndroid Build Coastguard Worker 267*da0073e9SAndroid Build Coastguard Worker fut = Future[torch.Tensor]() 268*da0073e9SAndroid Build Coastguard Worker fut.add_done_callback(callback_set1) 269*da0073e9SAndroid Build Coastguard Worker then_fut = fut.then(callback_then) 270*da0073e9SAndroid Build Coastguard Worker fut.add_done_callback(callback_set2) 271*da0073e9SAndroid Build Coastguard Worker 272*da0073e9SAndroid Build Coastguard Worker fut.set_result(torch.ones(2, 2)) 273*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fut.wait(), torch.ones(2, 2)) 274*da0073e9SAndroid Build Coastguard Worker # then_fut's callback is called with callback_result = 1 275*da0073e9SAndroid Build Coastguard Worker self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1) 276*da0073e9SAndroid Build Coastguard Worker # set2 called last, callback_result = 2 277*da0073e9SAndroid Build Coastguard Worker self.assertEqual(callback_result, 2) 278*da0073e9SAndroid Build Coastguard Worker 279*da0073e9SAndroid Build Coastguard Worker def test_interleaving_then_and_add_done_callback_propagates_error(self): 280*da0073e9SAndroid Build Coastguard Worker def raise_value_error(fut): 281*da0073e9SAndroid Build Coastguard Worker raise ValueError("Expected error") 282*da0073e9SAndroid Build Coastguard Worker 283*da0073e9SAndroid Build Coastguard Worker fut = Future[torch.Tensor]() 284*da0073e9SAndroid Build Coastguard Worker then_fut = fut.then(raise_value_error) 285*da0073e9SAndroid Build Coastguard Worker fut.add_done_callback(raise_value_error) 286*da0073e9SAndroid Build Coastguard Worker fut.set_result(torch.ones(2, 2)) 287*da0073e9SAndroid Build Coastguard Worker 288*da0073e9SAndroid Build Coastguard Worker # error from add_done_callback's callback is swallowed 289*da0073e9SAndroid Build Coastguard Worker # error from then's callback is not 290*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fut.wait(), torch.ones(2, 2)) 291*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Expected error"): 292*da0073e9SAndroid Build Coastguard Worker then_fut.wait() 293*da0073e9SAndroid Build Coastguard Worker 294*da0073e9SAndroid Build Coastguard Worker def test_collect_all(self): 295*da0073e9SAndroid Build Coastguard Worker fut1 = Future[int]() 296*da0073e9SAndroid Build Coastguard Worker fut2 = Future[int]() 297*da0073e9SAndroid Build Coastguard Worker fut_all = torch.futures.collect_all([fut1, fut2]) 298*da0073e9SAndroid Build Coastguard Worker 299*da0073e9SAndroid Build Coastguard Worker def slow_in_thread(fut, value): 300*da0073e9SAndroid Build Coastguard Worker time.sleep(0.1) 301*da0073e9SAndroid Build Coastguard Worker fut.set_result(value) 302*da0073e9SAndroid Build Coastguard Worker 303*da0073e9SAndroid Build Coastguard Worker t = threading.Thread(target=slow_in_thread, args=(fut1, 1)) 304*da0073e9SAndroid Build Coastguard Worker fut2.set_result(2) 305*da0073e9SAndroid Build Coastguard Worker t.start() 306*da0073e9SAndroid Build Coastguard Worker 307*da0073e9SAndroid Build Coastguard Worker res = fut_all.wait() 308*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res[0].wait(), 1) 309*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res[1].wait(), 2) 310*da0073e9SAndroid Build Coastguard Worker t.join() 311*da0073e9SAndroid Build Coastguard Worker 312*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_WINDOWS, "TODO: need to fix this testcase for Windows") 313*da0073e9SAndroid Build Coastguard Worker def test_wait_all(self): 314*da0073e9SAndroid Build Coastguard Worker fut1 = Future[int]() 315*da0073e9SAndroid Build Coastguard Worker fut2 = Future[int]() 316*da0073e9SAndroid Build Coastguard Worker 317*da0073e9SAndroid Build Coastguard Worker # No error version 318*da0073e9SAndroid Build Coastguard Worker fut1.set_result(1) 319*da0073e9SAndroid Build Coastguard Worker fut2.set_result(2) 320*da0073e9SAndroid Build Coastguard Worker res = torch.futures.wait_all([fut1, fut2]) 321*da0073e9SAndroid Build Coastguard Worker print(res) 322*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, [1, 2]) 323*da0073e9SAndroid Build Coastguard Worker 324*da0073e9SAndroid Build Coastguard Worker # Version with an exception 325*da0073e9SAndroid Build Coastguard Worker def raise_in_fut(fut): 326*da0073e9SAndroid Build Coastguard Worker raise ValueError("Expected error") 327*da0073e9SAndroid Build Coastguard Worker fut3 = fut1.then(raise_in_fut) 328*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Expected error"): 329*da0073e9SAndroid Build Coastguard Worker torch.futures.wait_all([fut3, fut2]) 330*da0073e9SAndroid Build Coastguard Worker 331*da0073e9SAndroid Build Coastguard Worker def test_wait_none(self): 332*da0073e9SAndroid Build Coastguard Worker fut1 = Future[int]() 333*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Future can't be None"): 334*da0073e9SAndroid Build Coastguard Worker torch.jit.wait(None) 335*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Future can't be None"): 336*da0073e9SAndroid Build Coastguard Worker torch.futures.wait_all((None,)) # type: ignore[arg-type] 337*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Future can't be None"): 338*da0073e9SAndroid Build Coastguard Worker torch.futures.collect_all((fut1, None,)) # type: ignore[arg-type] 339*da0073e9SAndroid Build Coastguard Worker 340*da0073e9SAndroid Build Coastguard Workerif __name__ == '__main__': 341*da0073e9SAndroid Build Coastguard Worker run_tests() 342