xref: /aosp_15_r20/external/pytorch/test/test_futures.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2# Owner(s): ["module: unknown"]
3
4import threading
5import time
6import torch
7import unittest
8from torch.futures import Future
9from torch.testing._internal.common_utils import IS_WINDOWS, TestCase, TemporaryFileName, run_tests
10from typing import TypeVar
11
# Type variable used by the tests below to instantiate ``Future[T]()``
# without naming a concrete result type.
T = TypeVar("T")
13
14
def add_one(fut):
    """Return the result of ``fut.wait()`` incremented by one."""
    value = fut.wait()
    return value + 1
17
18
19class TestFuture(TestCase):
20    def test_set_exception(self) -> None:
21        # This test is to ensure errors can propagate across futures.
22        error_msg = "Intentional Value Error"
23        value_error = ValueError(error_msg)
24
25        f = Future[T]()  # type: ignore[valid-type]
26        # Set exception
27        f.set_exception(value_error)
28        # Exception should throw on wait
29        with self.assertRaisesRegex(ValueError, "Intentional"):
30            f.wait()
31
32        # Exception should also throw on value
33        f = Future[T]()  # type: ignore[valid-type]
34        f.set_exception(value_error)
35        with self.assertRaisesRegex(ValueError, "Intentional"):
36            f.value()
37
38        def cb(fut):
39            fut.value()
40
41        f = Future[T]()  # type: ignore[valid-type]
42        f.set_exception(value_error)
43
44        with self.assertRaisesRegex(RuntimeError, "Got the following error"):
45            cb_fut = f.then(cb)
46            cb_fut.wait()
47
48    def test_set_exception_multithreading(self) -> None:
49        # Ensure errors can propagate when one thread waits on future result
50        # and the other sets it with an error.
51        error_msg = "Intentional Value Error"
52        value_error = ValueError(error_msg)
53
54        def wait_future(f):
55            with self.assertRaisesRegex(ValueError, "Intentional"):
56                f.wait()
57
58        f = Future[T]()  # type: ignore[valid-type]
59        t = threading.Thread(target=wait_future, args=(f, ))
60        t.start()
61        f.set_exception(value_error)
62        t.join()
63
64        def cb(fut):
65            fut.value()
66
67        def then_future(f):
68            fut = f.then(cb)
69            with self.assertRaisesRegex(RuntimeError, "Got the following error"):
70                fut.wait()
71
72        f = Future[T]()  # type: ignore[valid-type]
73        t = threading.Thread(target=then_future, args=(f, ))
74        t.start()
75        f.set_exception(value_error)
76        t.join()
77
78    def test_done(self) -> None:
79        f = Future[torch.Tensor]()
80        self.assertFalse(f.done())
81
82        f.set_result(torch.ones(2, 2))
83        self.assertTrue(f.done())
84
85    def test_done_exception(self) -> None:
86        err_msg = "Intentional Value Error"
87
88        def raise_exception(unused_future):
89            raise RuntimeError(err_msg)
90
91        f1 = Future[torch.Tensor]()
92        self.assertFalse(f1.done())
93        f1.set_result(torch.ones(2, 2))
94        self.assertTrue(f1.done())
95
96        f2 = f1.then(raise_exception)
97        self.assertTrue(f2.done())
98        with self.assertRaisesRegex(RuntimeError, err_msg):
99            f2.wait()
100
101    def test_wait(self) -> None:
102        f = Future[torch.Tensor]()
103        f.set_result(torch.ones(2, 2))
104
105        self.assertEqual(f.wait(), torch.ones(2, 2))
106
107    def test_wait_multi_thread(self) -> None:
108
109        def slow_set_future(fut, value):
110            time.sleep(0.5)
111            fut.set_result(value)
112
113        f = Future[torch.Tensor]()
114
115        t = threading.Thread(target=slow_set_future, args=(f, torch.ones(2, 2)))
116        t.start()
117
118        self.assertEqual(f.wait(), torch.ones(2, 2))
119        t.join()
120
121    def test_mark_future_twice(self) -> None:
122        fut = Future[int]()
123        fut.set_result(1)
124        with self.assertRaisesRegex(
125            RuntimeError,
126            "Future can only be marked completed once"
127        ):
128            fut.set_result(1)
129
130    def test_pickle_future(self):
131        fut = Future[int]()
132        errMsg = "Can not pickle torch.futures.Future"
133        with TemporaryFileName() as fname:
134            with self.assertRaisesRegex(RuntimeError, errMsg):
135                torch.save(fut, fname)
136
137    def test_then(self):
138        fut = Future[torch.Tensor]()
139        then_fut = fut.then(lambda x: x.wait() + 1)
140
141        fut.set_result(torch.ones(2, 2))
142        self.assertEqual(fut.wait(), torch.ones(2, 2))
143        self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1)
144
145    def test_chained_then(self):
146        fut = Future[torch.Tensor]()
147        futs = []
148        last_fut = fut
149        for _ in range(20):
150            last_fut = last_fut.then(add_one)
151            futs.append(last_fut)
152
153        fut.set_result(torch.ones(2, 2))
154
155        for i in range(len(futs)):
156            self.assertEqual(futs[i].wait(), torch.ones(2, 2) + i + 1)
157
158    def _test_then_error(self, cb, errMsg):
159        fut = Future[int]()
160        then_fut = fut.then(cb)
161
162        fut.set_result(5)
163        self.assertEqual(5, fut.wait())
164        with self.assertRaisesRegex(RuntimeError, errMsg):
165            then_fut.wait()
166
167    def test_then_wrong_arg(self):
168
169        def wrong_arg(tensor):
170            return tensor + 1
171
172        self._test_then_error(wrong_arg, "unsupported operand type.*Future.*int")
173
174    def test_then_no_arg(self):
175
176        def no_arg():
177            return True
178
179        self._test_then_error(no_arg, "takes 0 positional arguments but 1 was given")
180
181    def test_then_raise(self):
182
183        def raise_value_error(fut):
184            raise ValueError("Expected error")
185
186        self._test_then_error(raise_value_error, "Expected error")
187
188    def test_add_done_callback_simple(self):
189        callback_result = False
190
191        def callback(fut):
192            nonlocal callback_result
193            fut.wait()
194            callback_result = True
195
196        fut = Future[torch.Tensor]()
197        fut.add_done_callback(callback)
198
199        self.assertFalse(callback_result)
200        fut.set_result(torch.ones(2, 2))
201        self.assertEqual(fut.wait(), torch.ones(2, 2))
202        self.assertTrue(callback_result)
203
204    def test_add_done_callback_maintains_callback_order(self):
205        callback_result = 0
206
207        def callback_set1(fut):
208            nonlocal callback_result
209            fut.wait()
210            callback_result = 1
211
212        def callback_set2(fut):
213            nonlocal callback_result
214            fut.wait()
215            callback_result = 2
216
217        fut = Future[torch.Tensor]()
218        fut.add_done_callback(callback_set1)
219        fut.add_done_callback(callback_set2)
220
221        fut.set_result(torch.ones(2, 2))
222        self.assertEqual(fut.wait(), torch.ones(2, 2))
223        # set2 called last, callback_result = 2
224        self.assertEqual(callback_result, 2)
225
226    def _test_add_done_callback_error_ignored(self, cb):
227        fut = Future[int]()
228        fut.add_done_callback(cb)
229
230        fut.set_result(5)
231        # error msg logged to stdout
232        self.assertEqual(5, fut.wait())
233
234    def test_add_done_callback_error_is_ignored(self):
235
236        def raise_value_error(fut):
237            raise ValueError("Expected error")
238
239        self._test_add_done_callback_error_ignored(raise_value_error)
240
241    def test_add_done_callback_no_arg_error_is_ignored(self):
242
243        def no_arg():
244            return True
245
246        # Adding another level of function indirection here on purpose.
247        # Otherwise mypy will pick up on no_arg having an incompatible type and fail CI
248        self._test_add_done_callback_error_ignored(no_arg)
249
250    def test_interleaving_then_and_add_done_callback_maintains_callback_order(self):
251        callback_result = 0
252
253        def callback_set1(fut):
254            nonlocal callback_result
255            fut.wait()
256            callback_result = 1
257
258        def callback_set2(fut):
259            nonlocal callback_result
260            fut.wait()
261            callback_result = 2
262
263        def callback_then(fut):
264            nonlocal callback_result
265            return fut.wait() + callback_result
266
267        fut = Future[torch.Tensor]()
268        fut.add_done_callback(callback_set1)
269        then_fut = fut.then(callback_then)
270        fut.add_done_callback(callback_set2)
271
272        fut.set_result(torch.ones(2, 2))
273        self.assertEqual(fut.wait(), torch.ones(2, 2))
274        # then_fut's callback is called with callback_result = 1
275        self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1)
276        # set2 called last, callback_result = 2
277        self.assertEqual(callback_result, 2)
278
279    def test_interleaving_then_and_add_done_callback_propagates_error(self):
280        def raise_value_error(fut):
281            raise ValueError("Expected error")
282
283        fut = Future[torch.Tensor]()
284        then_fut = fut.then(raise_value_error)
285        fut.add_done_callback(raise_value_error)
286        fut.set_result(torch.ones(2, 2))
287
288        # error from add_done_callback's callback is swallowed
289        # error from then's callback is not
290        self.assertEqual(fut.wait(), torch.ones(2, 2))
291        with self.assertRaisesRegex(RuntimeError, "Expected error"):
292            then_fut.wait()
293
294    def test_collect_all(self):
295        fut1 = Future[int]()
296        fut2 = Future[int]()
297        fut_all = torch.futures.collect_all([fut1, fut2])
298
299        def slow_in_thread(fut, value):
300            time.sleep(0.1)
301            fut.set_result(value)
302
303        t = threading.Thread(target=slow_in_thread, args=(fut1, 1))
304        fut2.set_result(2)
305        t.start()
306
307        res = fut_all.wait()
308        self.assertEqual(res[0].wait(), 1)
309        self.assertEqual(res[1].wait(), 2)
310        t.join()
311
312    @unittest.skipIf(IS_WINDOWS, "TODO: need to fix this testcase for Windows")
313    def test_wait_all(self):
314        fut1 = Future[int]()
315        fut2 = Future[int]()
316
317        # No error version
318        fut1.set_result(1)
319        fut2.set_result(2)
320        res = torch.futures.wait_all([fut1, fut2])
321        print(res)
322        self.assertEqual(res, [1, 2])
323
324        # Version with an exception
325        def raise_in_fut(fut):
326            raise ValueError("Expected error")
327        fut3 = fut1.then(raise_in_fut)
328        with self.assertRaisesRegex(RuntimeError, "Expected error"):
329            torch.futures.wait_all([fut3, fut2])
330
331    def test_wait_none(self):
332        fut1 = Future[int]()
333        with self.assertRaisesRegex(RuntimeError, "Future can't be None"):
334            torch.jit.wait(None)
335        with self.assertRaisesRegex(RuntimeError, "Future can't be None"):
336            torch.futures.wait_all((None,))  # type: ignore[arg-type]
337        with self.assertRaisesRegex(RuntimeError, "Future can't be None"):
338            torch.futures.collect_all((fut1, None,))  # type: ignore[arg-type]
339
# When executed as a script, hand control to run_tests (imported from
# torch.testing._internal.common_utils).
if __name__ == '__main__':
    run_tests()
342