1# mypy: allow-untyped-defs 2# Owner(s): ["module: unknown"] 3 4import threading 5import time 6import torch 7import unittest 8from torch.futures import Future 9from torch.testing._internal.common_utils import IS_WINDOWS, TestCase, TemporaryFileName, run_tests 10from typing import TypeVar 11 12T = TypeVar("T") 13 14 15def add_one(fut): 16 return fut.wait() + 1 17 18 19class TestFuture(TestCase): 20 def test_set_exception(self) -> None: 21 # This test is to ensure errors can propagate across futures. 22 error_msg = "Intentional Value Error" 23 value_error = ValueError(error_msg) 24 25 f = Future[T]() # type: ignore[valid-type] 26 # Set exception 27 f.set_exception(value_error) 28 # Exception should throw on wait 29 with self.assertRaisesRegex(ValueError, "Intentional"): 30 f.wait() 31 32 # Exception should also throw on value 33 f = Future[T]() # type: ignore[valid-type] 34 f.set_exception(value_error) 35 with self.assertRaisesRegex(ValueError, "Intentional"): 36 f.value() 37 38 def cb(fut): 39 fut.value() 40 41 f = Future[T]() # type: ignore[valid-type] 42 f.set_exception(value_error) 43 44 with self.assertRaisesRegex(RuntimeError, "Got the following error"): 45 cb_fut = f.then(cb) 46 cb_fut.wait() 47 48 def test_set_exception_multithreading(self) -> None: 49 # Ensure errors can propagate when one thread waits on future result 50 # and the other sets it with an error. 51 error_msg = "Intentional Value Error" 52 value_error = ValueError(error_msg) 53 54 def wait_future(f): 55 with self.assertRaisesRegex(ValueError, "Intentional"): 56 f.wait() 57 58 f = Future[T]() # type: ignore[valid-type] 59 t = threading.Thread(target=wait_future, args=(f, )) 60 t.start() 61 f.set_exception(value_error) 62 t.join() 63 64 def cb(fut): 65 fut.value() 66 67 def then_future(f): 68 fut = f.then(cb) 69 with self.assertRaisesRegex(RuntimeError, "Got the following error"): 70 fut.wait() 71 72 f = Future[T]() # type: ignore[valid-type] 73 t = threading.Thread(target=then_future, args=(f, )) 74 t.start() 75 f.set_exception(value_error) 76 t.join() 77 78 def test_done(self) -> None: 79 f = Future[torch.Tensor]() 80 self.assertFalse(f.done()) 81 82 f.set_result(torch.ones(2, 2)) 83 self.assertTrue(f.done()) 84 85 def test_done_exception(self) -> None: 86 err_msg = "Intentional Value Error" 87 88 def raise_exception(unused_future): 89 raise RuntimeError(err_msg) 90 91 f1 = Future[torch.Tensor]() 92 self.assertFalse(f1.done()) 93 f1.set_result(torch.ones(2, 2)) 94 self.assertTrue(f1.done()) 95 96 f2 = f1.then(raise_exception) 97 self.assertTrue(f2.done()) 98 with self.assertRaisesRegex(RuntimeError, err_msg): 99 f2.wait() 100 101 def test_wait(self) -> None: 102 f = Future[torch.Tensor]() 103 f.set_result(torch.ones(2, 2)) 104 105 self.assertEqual(f.wait(), torch.ones(2, 2)) 106 107 def test_wait_multi_thread(self) -> None: 108 109 def slow_set_future(fut, value): 110 time.sleep(0.5) 111 fut.set_result(value) 112 113 f = Future[torch.Tensor]() 114 115 t = threading.Thread(target=slow_set_future, args=(f, torch.ones(2, 2))) 116 t.start() 117 118 self.assertEqual(f.wait(), torch.ones(2, 2)) 119 t.join() 120 121 def test_mark_future_twice(self) -> None: 122 fut = Future[int]() 123 fut.set_result(1) 124 with self.assertRaisesRegex( 125 RuntimeError, 126 "Future can only be marked completed once" 127 ): 128 fut.set_result(1) 129 130 def test_pickle_future(self): 131 fut = Future[int]() 132 errMsg = "Can not pickle torch.futures.Future" 133 with TemporaryFileName() as fname: 134 with self.assertRaisesRegex(RuntimeError, errMsg): 135 torch.save(fut, fname) 136 137 def test_then(self): 138 fut = Future[torch.Tensor]() 139 then_fut = fut.then(lambda x: x.wait() + 1) 140 141 fut.set_result(torch.ones(2, 2)) 142 self.assertEqual(fut.wait(), torch.ones(2, 2)) 143 self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1) 144 145 def test_chained_then(self): 146 fut = Future[torch.Tensor]() 147 futs = [] 148 last_fut = fut 149 for _ in range(20): 150 last_fut = last_fut.then(add_one) 151 futs.append(last_fut) 152 153 fut.set_result(torch.ones(2, 2)) 154 155 for i in range(len(futs)): 156 self.assertEqual(futs[i].wait(), torch.ones(2, 2) + i + 1) 157 158 def _test_then_error(self, cb, errMsg): 159 fut = Future[int]() 160 then_fut = fut.then(cb) 161 162 fut.set_result(5) 163 self.assertEqual(5, fut.wait()) 164 with self.assertRaisesRegex(RuntimeError, errMsg): 165 then_fut.wait() 166 167 def test_then_wrong_arg(self): 168 169 def wrong_arg(tensor): 170 return tensor + 1 171 172 self._test_then_error(wrong_arg, "unsupported operand type.*Future.*int") 173 174 def test_then_no_arg(self): 175 176 def no_arg(): 177 return True 178 179 self._test_then_error(no_arg, "takes 0 positional arguments but 1 was given") 180 181 def test_then_raise(self): 182 183 def raise_value_error(fut): 184 raise ValueError("Expected error") 185 186 self._test_then_error(raise_value_error, "Expected error") 187 188 def test_add_done_callback_simple(self): 189 callback_result = False 190 191 def callback(fut): 192 nonlocal callback_result 193 fut.wait() 194 callback_result = True 195 196 fut = Future[torch.Tensor]() 197 fut.add_done_callback(callback) 198 199 self.assertFalse(callback_result) 200 fut.set_result(torch.ones(2, 2)) 201 self.assertEqual(fut.wait(), torch.ones(2, 2)) 202 self.assertTrue(callback_result) 203 204 def test_add_done_callback_maintains_callback_order(self): 205 callback_result = 0 206 207 def callback_set1(fut): 208 nonlocal callback_result 209 fut.wait() 210 callback_result = 1 211 212 def callback_set2(fut): 213 nonlocal callback_result 214 fut.wait() 215 callback_result = 2 216 217 fut = Future[torch.Tensor]() 218 fut.add_done_callback(callback_set1) 219 fut.add_done_callback(callback_set2) 220 221 fut.set_result(torch.ones(2, 2)) 222 self.assertEqual(fut.wait(), torch.ones(2, 2)) 223 # set2 called last, callback_result = 2 224 self.assertEqual(callback_result, 2) 225 226 def _test_add_done_callback_error_ignored(self, cb): 227 fut = Future[int]() 228 fut.add_done_callback(cb) 229 230 fut.set_result(5) 231 # error msg logged to stdout 232 self.assertEqual(5, fut.wait()) 233 234 def test_add_done_callback_error_is_ignored(self): 235 236 def raise_value_error(fut): 237 raise ValueError("Expected error") 238 239 self._test_add_done_callback_error_ignored(raise_value_error) 240 241 def test_add_done_callback_no_arg_error_is_ignored(self): 242 243 def no_arg(): 244 return True 245 246 # Adding another level of function indirection here on purpose. 247 # Otherwise mypy will pick up on no_arg having an incompatible type and fail CI 248 self._test_add_done_callback_error_ignored(no_arg) 249 250 def test_interleaving_then_and_add_done_callback_maintains_callback_order(self): 251 callback_result = 0 252 253 def callback_set1(fut): 254 nonlocal callback_result 255 fut.wait() 256 callback_result = 1 257 258 def callback_set2(fut): 259 nonlocal callback_result 260 fut.wait() 261 callback_result = 2 262 263 def callback_then(fut): 264 nonlocal callback_result 265 return fut.wait() + callback_result 266 267 fut = Future[torch.Tensor]() 268 fut.add_done_callback(callback_set1) 269 then_fut = fut.then(callback_then) 270 fut.add_done_callback(callback_set2) 271 272 fut.set_result(torch.ones(2, 2)) 273 self.assertEqual(fut.wait(), torch.ones(2, 2)) 274 # then_fut's callback is called with callback_result = 1 275 self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1) 276 # set2 called last, callback_result = 2 277 self.assertEqual(callback_result, 2) 278 279 def test_interleaving_then_and_add_done_callback_propagates_error(self): 280 def raise_value_error(fut): 281 raise ValueError("Expected error") 282 283 fut = Future[torch.Tensor]() 284 then_fut = fut.then(raise_value_error) 285 fut.add_done_callback(raise_value_error) 286 fut.set_result(torch.ones(2, 2)) 287 288 # error from add_done_callback's callback is swallowed 289 # error from then's callback is not 290 self.assertEqual(fut.wait(), torch.ones(2, 2)) 291 with self.assertRaisesRegex(RuntimeError, "Expected error"): 292 then_fut.wait() 293 294 def test_collect_all(self): 295 fut1 = Future[int]() 296 fut2 = Future[int]() 297 fut_all = torch.futures.collect_all([fut1, fut2]) 298 299 def slow_in_thread(fut, value): 300 time.sleep(0.1) 301 fut.set_result(value) 302 303 t = threading.Thread(target=slow_in_thread, args=(fut1, 1)) 304 fut2.set_result(2) 305 t.start() 306 307 res = fut_all.wait() 308 self.assertEqual(res[0].wait(), 1) 309 self.assertEqual(res[1].wait(), 2) 310 t.join() 311 312 @unittest.skipIf(IS_WINDOWS, "TODO: need to fix this testcase for Windows") 313 def test_wait_all(self): 314 fut1 = Future[int]() 315 fut2 = Future[int]() 316 317 # No error version 318 fut1.set_result(1) 319 fut2.set_result(2) 320 res = torch.futures.wait_all([fut1, fut2]) 321 print(res) 322 self.assertEqual(res, [1, 2]) 323 324 # Version with an exception 325 def raise_in_fut(fut): 326 raise ValueError("Expected error") 327 fut3 = fut1.then(raise_in_fut) 328 with self.assertRaisesRegex(RuntimeError, "Expected error"): 329 torch.futures.wait_all([fut3, fut2]) 330 331 def test_wait_none(self): 332 fut1 = Future[int]() 333 with self.assertRaisesRegex(RuntimeError, "Future can't be None"): 334 torch.jit.wait(None) 335 with self.assertRaisesRegex(RuntimeError, "Future can't be None"): 336 torch.futures.wait_all((None,)) # type: ignore[arg-type] 337 with self.assertRaisesRegex(RuntimeError, "Future can't be None"): 338 torch.futures.collect_all((fut1, None,)) # type: ignore[arg-type] 339 340if __name__ == '__main__': 341 run_tests() 342