xref: /aosp_15_r20/external/pytorch/test/test_futures.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: unknown"]
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerimport threading
5*da0073e9SAndroid Build Coastguard Workerimport time
6*da0073e9SAndroid Build Coastguard Workerimport torch
7*da0073e9SAndroid Build Coastguard Workerimport unittest
8*da0073e9SAndroid Build Coastguard Workerfrom torch.futures import Future
9*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import IS_WINDOWS, TestCase, TemporaryFileName, run_tests
10*da0073e9SAndroid Build Coastguard Workerfrom typing import TypeVar
11*da0073e9SAndroid Build Coastguard Worker
# Generic placeholder result type for the ``Future[T]()`` constructions in the
# tests below.
T = TypeVar("T")


def add_one(fut):
    """``Future.then`` callback: block on *fut* and return its value plus one."""
    value = fut.wait()
    return value + 1
17*da0073e9SAndroid Build Coastguard Worker
18*da0073e9SAndroid Build Coastguard Worker
class TestFuture(TestCase):
    """Tests for ``torch.futures.Future`` (completion, error propagation,
    ``then`` / ``add_done_callback`` semantics) and the ``collect_all`` /
    ``wait_all`` helpers, including use of futures across threads."""

    def _run_thread_and_reraise(self, target, args, trigger) -> None:
        """Run ``target(*args)`` on a worker thread while ``trigger()`` runs on
        the calling thread, join the worker, then re-raise what it raised.

        unittest only observes exceptions raised on the thread executing the
        test method. An assertion that fails inside a ``threading.Thread`` is
        merely reported through ``threading.excepthook`` and would not fail
        the test. Capturing the worker's exception and re-raising it here
        makes such a failure fail the test.
        """
        errors = []

        def runner():
            try:
                target(*args)
            except Exception as exc:
                errors.append(exc)

        t = threading.Thread(target=runner)
        t.start()
        trigger()
        t.join()
        if errors:
            raise errors[0]

    def test_set_exception(self) -> None:
        # This test is to ensure errors can propagate across futures.
        error_msg = "Intentional Value Error"
        value_error = ValueError(error_msg)

        f = Future[T]()  # type: ignore[valid-type]
        # Set exception
        f.set_exception(value_error)
        # Exception should throw on wait
        with self.assertRaisesRegex(ValueError, "Intentional"):
            f.wait()

        # Exception should also throw on value
        f = Future[T]()  # type: ignore[valid-type]
        f.set_exception(value_error)
        with self.assertRaisesRegex(ValueError, "Intentional"):
            f.value()

        # A ``then`` callback that reads the errored value fails; the chained
        # future is expected to report that as a RuntimeError.
        def cb(fut):
            fut.value()

        f = Future[T]()  # type: ignore[valid-type]
        f.set_exception(value_error)

        with self.assertRaisesRegex(RuntimeError, "Got the following error"):
            cb_fut = f.then(cb)
            cb_fut.wait()

    def test_set_exception_multithreading(self) -> None:
        # Ensure errors can propagate when one thread waits on future result
        # and the other sets it with an error. The worker-side assertions are
        # routed through _run_thread_and_reraise so a failure there actually
        # fails this test instead of only being printed by the thread.
        error_msg = "Intentional Value Error"
        value_error = ValueError(error_msg)

        def wait_future(f):
            with self.assertRaisesRegex(ValueError, "Intentional"):
                f.wait()

        wait_fut = Future[T]()  # type: ignore[valid-type]
        self._run_thread_and_reraise(
            wait_future, (wait_fut,), lambda: wait_fut.set_exception(value_error)
        )

        def cb(fut):
            fut.value()

        def then_future(f):
            fut = f.then(cb)
            with self.assertRaisesRegex(RuntimeError, "Got the following error"):
                fut.wait()

        then_src = Future[T]()  # type: ignore[valid-type]
        self._run_thread_and_reraise(
            then_future, (then_src,), lambda: then_src.set_exception(value_error)
        )

    def test_done(self) -> None:
        f = Future[torch.Tensor]()
        self.assertFalse(f.done())

        f.set_result(torch.ones(2, 2))
        self.assertTrue(f.done())

    def test_done_exception(self) -> None:
        err_msg = "Intentional Value Error"

        def raise_exception(unused_future):
            raise RuntimeError(err_msg)

        f1 = Future[torch.Tensor]()
        self.assertFalse(f1.done())
        f1.set_result(torch.ones(2, 2))
        self.assertTrue(f1.done())

        # The chained future completes (with an error) as soon as the callback
        # raises, because f1 is already done.
        f2 = f1.then(raise_exception)
        self.assertTrue(f2.done())
        with self.assertRaisesRegex(RuntimeError, err_msg):
            f2.wait()

    def test_wait(self) -> None:
        f = Future[torch.Tensor]()
        f.set_result(torch.ones(2, 2))

        self.assertEqual(f.wait(), torch.ones(2, 2))

    def test_wait_multi_thread(self) -> None:

        def slow_set_future(fut, value):
            time.sleep(0.5)
            fut.set_result(value)

        f = Future[torch.Tensor]()

        t = threading.Thread(target=slow_set_future, args=(f, torch.ones(2, 2)))
        t.start()

        # wait() must block until the worker thread sets the result.
        self.assertEqual(f.wait(), torch.ones(2, 2))
        t.join()

    def test_mark_future_twice(self) -> None:
        fut = Future[int]()
        fut.set_result(1)
        with self.assertRaisesRegex(
            RuntimeError,
            "Future can only be marked completed once"
        ):
            fut.set_result(1)

    def test_pickle_future(self):
        fut = Future[int]()
        errMsg = "Can not pickle torch.futures.Future"
        with TemporaryFileName() as fname:
            with self.assertRaisesRegex(RuntimeError, errMsg):
                torch.save(fut, fname)

    def test_then(self):
        fut = Future[torch.Tensor]()
        then_fut = fut.then(lambda x: x.wait() + 1)

        fut.set_result(torch.ones(2, 2))
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1)

    def test_chained_then(self):
        fut = Future[torch.Tensor]()
        futs = []
        last_fut = fut
        for _ in range(20):
            last_fut = last_fut.then(add_one)
            futs.append(last_fut)

        fut.set_result(torch.ones(2, 2))

        # The i-th chained future has had add_one applied i + 1 times.
        for i, chained in enumerate(futs):
            self.assertEqual(chained.wait(), torch.ones(2, 2) + i + 1)

    def _test_then_error(self, cb, errMsg):
        # Chain ``cb`` onto a future completed with 5 and check that the
        # callback's failure surfaces as a RuntimeError matching ``errMsg``
        # on the chained future, while the source future is unaffected.
        fut = Future[int]()
        then_fut = fut.then(cb)

        fut.set_result(5)
        self.assertEqual(5, fut.wait())
        with self.assertRaisesRegex(RuntimeError, errMsg):
            then_fut.wait()

    def test_then_wrong_arg(self):

        def wrong_arg(tensor):
            return tensor + 1

        self._test_then_error(wrong_arg, "unsupported operand type.*Future.*int")

    def test_then_no_arg(self):

        def no_arg():
            return True

        self._test_then_error(no_arg, "takes 0 positional arguments but 1 was given")

    def test_then_raise(self):

        def raise_value_error(fut):
            raise ValueError("Expected error")

        self._test_then_error(raise_value_error, "Expected error")

    def test_add_done_callback_simple(self):
        callback_result = False

        def callback(fut):
            nonlocal callback_result
            fut.wait()
            callback_result = True

        fut = Future[torch.Tensor]()
        fut.add_done_callback(callback)

        self.assertFalse(callback_result)
        fut.set_result(torch.ones(2, 2))
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        self.assertTrue(callback_result)

    def test_add_done_callback_maintains_callback_order(self):
        callback_result = 0

        def callback_set1(fut):
            nonlocal callback_result
            fut.wait()
            callback_result = 1

        def callback_set2(fut):
            nonlocal callback_result
            fut.wait()
            callback_result = 2

        fut = Future[torch.Tensor]()
        fut.add_done_callback(callback_set1)
        fut.add_done_callback(callback_set2)

        fut.set_result(torch.ones(2, 2))
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        # set2 called last, callback_result = 2
        self.assertEqual(callback_result, 2)

    def _test_add_done_callback_error_ignored(self, cb):
        # A failing add_done_callback callback must not affect the future's
        # own result.
        fut = Future[int]()
        fut.add_done_callback(cb)

        fut.set_result(5)
        # error msg logged to stdout
        self.assertEqual(5, fut.wait())

    def test_add_done_callback_error_is_ignored(self):

        def raise_value_error(fut):
            raise ValueError("Expected error")

        self._test_add_done_callback_error_ignored(raise_value_error)

    def test_add_done_callback_no_arg_error_is_ignored(self):

        def no_arg():
            return True

        # Adding another level of function indirection here on purpose.
        # Otherwise mypy will pick up on no_arg having an incompatible type and fail CI
        self._test_add_done_callback_error_ignored(no_arg)

    def test_interleaving_then_and_add_done_callback_maintains_callback_order(self):
        callback_result = 0

        def callback_set1(fut):
            nonlocal callback_result
            fut.wait()
            callback_result = 1

        def callback_set2(fut):
            nonlocal callback_result
            fut.wait()
            callback_result = 2

        def callback_then(fut):
            # Only reads callback_result, so no ``nonlocal`` is needed.
            return fut.wait() + callback_result

        fut = Future[torch.Tensor]()
        fut.add_done_callback(callback_set1)
        then_fut = fut.then(callback_then)
        fut.add_done_callback(callback_set2)

        fut.set_result(torch.ones(2, 2))
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        # then_fut's callback is called with callback_result = 1
        self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1)
        # set2 called last, callback_result = 2
        self.assertEqual(callback_result, 2)

    def test_interleaving_then_and_add_done_callback_propagates_error(self):
        def raise_value_error(fut):
            raise ValueError("Expected error")

        fut = Future[torch.Tensor]()
        then_fut = fut.then(raise_value_error)
        fut.add_done_callback(raise_value_error)
        fut.set_result(torch.ones(2, 2))

        # error from add_done_callback's callback is swallowed
        # error from then's callback is not
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        with self.assertRaisesRegex(RuntimeError, "Expected error"):
            then_fut.wait()

    def test_collect_all(self):
        fut1 = Future[int]()
        fut2 = Future[int]()
        fut_all = torch.futures.collect_all([fut1, fut2])

        def slow_in_thread(fut, value):
            time.sleep(0.1)
            fut.set_result(value)

        # fut2 completes immediately, fut1 only later from the worker thread;
        # collect_all must wait for both.
        t = threading.Thread(target=slow_in_thread, args=(fut1, 1))
        fut2.set_result(2)
        t.start()

        res = fut_all.wait()
        self.assertEqual(res[0].wait(), 1)
        self.assertEqual(res[1].wait(), 2)
        t.join()

    @unittest.skipIf(IS_WINDOWS, "TODO: need to fix this testcase for Windows")
    def test_wait_all(self):
        fut1 = Future[int]()
        fut2 = Future[int]()

        # No error version
        fut1.set_result(1)
        fut2.set_result(2)
        res = torch.futures.wait_all([fut1, fut2])
        self.assertEqual(res, [1, 2])

        # Version with an exception
        def raise_in_fut(fut):
            raise ValueError("Expected error")

        fut3 = fut1.then(raise_in_fut)
        with self.assertRaisesRegex(RuntimeError, "Expected error"):
            torch.futures.wait_all([fut3, fut2])

    def test_wait_none(self):
        fut1 = Future[int]()
        with self.assertRaisesRegex(RuntimeError, "Future can't be None"):
            torch.jit.wait(None)
        with self.assertRaisesRegex(RuntimeError, "Future can't be None"):
            torch.futures.wait_all((None,))  # type: ignore[arg-type]
        with self.assertRaisesRegex(RuntimeError, "Future can't be None"):
            torch.futures.collect_all((fut1, None,))  # type: ignore[arg-type]
339*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Script entry point: delegate to common_utils.run_tests when this file
    # is executed directly rather than imported.
    run_tests()
342