xref: /aosp_15_r20/external/pytorch/test/dynamo/test_exceptions.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: dynamo"]
2
3import torch
4import torch._dynamo.config
5import torch._dynamo.test_case
6import torch._functorch.config
7import torch.nn
8import torch.utils.checkpoint
9
10
class ExceptionTests(torch._dynamo.test_case.TestCase):
    """Tests that Dynamo traces Python exception control flow
    (raise / try / except / else / finally, re-raise, exception chaining)
    so that compiled output matches eager execution.

    Each test builds a small ``fn``, computes an eager reference, compiles
    with ``backend="eager"`` (graph capture only, no codegen changes), and
    asserts the two results agree.
    """

    def test_exception(self):
        # Raise caught by a broad `except Exception` handler inside the
        # compiled region.
        def fn(x):
            x = torch.cos(x)
            try:
                x = torch.sin(x)
                raise NotImplementedError
            except Exception:
                x = torch.sigmoid(x)

            return x

        x = torch.randn(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_exception2(self):
        # Tuple of exception types plus an `as e` binding; `e` is unused on
        # purpose — the binding itself exercises Dynamo's handling of the
        # exception-variable store/delete.
        def fn(x):
            x = torch.cos(x)
            try:
                x = torch.sin(x)
                raise NotImplementedError
            except (NotImplementedError, AttributeError) as e:
                x = torch.sigmoid(x)

            return x

        x = torch.randn(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_exception3(self):
        # Multiple handlers (first one does NOT match) plus a `finally`
        # block that must run on the handled path.
        def fn(x):
            x = torch.cos(x)
            try:
                x = torch.sin(x)
                raise NotImplementedError("Not implemented")
            except AssertionError:
                x = torch.sigmoid(x)
            except NotImplementedError:
                x = torch.cos(x)
            finally:
                x = torch.cos(x)

            return x

        x = torch.randn(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_exception4(self):
        # raise/catch inside a loop body, with an early return cutting the
        # loop short.
        def fn(x):
            for i in range(10):
                if i == 5:
                    return x
                try:
                    x = torch.sin(x)
                    raise NotImplementedError
                except Exception:
                    x = torch.sigmoid(x)

            return x

        x = torch.randn(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_exception_with_another_exception(self):
        # A second try/except nested inside the outer handler.
        def fn(x):
            x = torch.cos(x)
            try:
                x = torch.sin(x)
                raise NotImplementedError("Not implemented")
            except NotImplementedError as e:
                x = torch.sigmoid(x)
                try:
                    x = torch.cos(x)
                    raise AssertionError
                except AssertionError:
                    x = torch.cos(x)
            # Fix: without this return, fn yielded None and the assertion
            # below compared None == None, verifying nothing.
            return x

        x = torch.randn(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_exception_else(self):
        # `else` branch of try/except must execute when no exception fires.
        def gn(x):
            return torch.cos(x)

        def fn(x):
            x = torch.cos(x)
            try:
                x = torch.sin(x)
                x = gn(x)
            except Exception:
                x = torch.sigmoid(x)
            else:
                x = torch.cos(x)

            return x

        x = torch.randn(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    # TODO(anijain2305) - does not work with fullgraph=True
    def test_exception_with_another_exception2(self):
        # Bare `raise` (re-raise) from inside a handler, caught by the
        # caller's broad handler.
        def gn(x):
            try:
                x = torch.cos(x)
                raise NotImplementedError("Not implemented")
            except NotImplementedError as e:
                x = torch.sigmoid(x)
                raise

        def fn(x):
            try:
                x = torch.cos(x)
                gn(x)
            except Exception:
                pass
            return x

        x = torch.randn(4)
        ref = fn(x)
        # Cant use fullgraph=True because RERAISE is not supported
        opt_fn = torch.compile(fn, backend="eager")
        res = opt_fn(x)
        # Fix: the results were computed but never compared.
        self.assertEqual(ref, res)

    # TODO(anijain2305) - does not work with fullgraph=True
    def test_exception_with_ctx_manager(self):
        # Exception raised while a context manager is active; its __exit__
        # must still run before the handler.
        def fn(x):
            x = torch.cos(x)
            try:
                with torch.no_grad():
                    x = torch.sin(x)
                    raise NotImplementedError("Not implemented")
            except NotImplementedError as e:
                x = torch.sigmoid(x)
            return x

        x = torch.randn(4)
        ref = fn(x)
        # Cant use fullgraph=True because WITH_EXCEPT_START is not supported
        opt_fn = torch.compile(fn, backend="eager")
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_exception_raised_from_child(self):
        # Exception propagates out of an inlined callee into the caller's
        # handler; the statement after the call must be skipped.
        def gn():
            raise NotImplementedError("foo")

        def fn(x):
            x = torch.cos(x)
            try:
                x = torch.sin(x)
                gn()
                x = torch.sin(x)
            except Exception:
                x = torch.sigmoid(x)

            return x

        x = torch.randn(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_dynamo_undo_kw_names(self):
        # A kwargs call raises mid-trace; Dynamo must roll back the
        # KW_NAMES state so the handler traces correctly.
        def g(x, k=None):
            if k:
                raise TypeError("error")
            return x.sin()

        def fn(x):
            d = {"a": x}
            try:
                g(x, k=True)
            except Exception:
                y = 0
                for _, b in d.items():  # noqa: PERF102
                    y += b.sum()
            return y

        x = torch.randn(2, 3)
        expected = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        got = opt_fn(x)
        self.assertEqual(expected, got)

    def test_nn_module_getattr(self):
        # __getattr__ chains: B delegates to A, which may raise
        # AttributeError that B swallows with a fallback value.
        class A:
            def __init__(self) -> None:
                self._b = 20

            def __getattr__(self, name):
                fixed_name = "_" + name
                if fixed_name in self.__dict__:
                    return self.__dict__[fixed_name]
                raise AttributeError(f"{name} absent")

        class B(A):
            def __init__(self) -> None:
                self.a = 10

            def __getattr__(self, name):
                try:
                    return super().__getattr__(name)
                except AttributeError:
                    return 30

        obj = B()

        def fn(x):
            # a -> direct attr (10), b -> A's fallback (_b = 20),
            # c -> B's except-path fallback (30)
            return x * obj.a * obj.b * obj.c

        x = torch.ones(4)
        ref = fn(x)
        # Fix: removed stray debug print(ref).
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
    def test_custom_getattr_on_module_exception(self):
        # nn.Module subclass whose __getattr__ defers to nn.Module, catches
        # its AttributeError, and either synthesizes a value or re-raises.
        class Foo(torch.nn.Module):
            def __init__(self, a=3):
                super().__init__()
                self.register_parameter("a", torch.nn.Parameter(torch.ones(4) * 2))

            def __getattr__(self, name):
                try:
                    return super().__getattr__(name)  # defer to nn.Module's logic
                except AttributeError:
                    if name == "a_copy":
                        return self.a
                    raise

            def forward(self, x):
                return x * self.a * self.a_copy

        mod = Foo()
        opt_mod = torch.compile(mod, backend="eager", fullgraph=True)

        x = torch.ones(4)
        self.assertEqual(mod(x), opt_mod(x))

    def test_attribute_error_from_getattr(self):
        # hasattr() must observe the AttributeError raised by a custom
        # __getattr__ and report False.
        class Mock:
            def __init__(self):
                self.a = 5

            def __getattr__(self, name):
                if name != "a":
                    raise AttributeError("missing")
                return self.__dict__["a"]

        mock = Mock()

        def fn(x):
            if hasattr(mock, "b"):
                return torch.cos(x)
            return torch.sin(x)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_stop_iteration(self):
        # StopIteration from next() used as control flow inside a compiled
        # region (intentionally quirky zip_longest that returns on the
        # first exhausted iterator).
        def zip_longest(*iterables, fillvalue=None):
            # Get the iterators for each iterable
            iterators = [iter(it) for it in iterables]

            result = []
            while True:
                for it in iterators:
                    try:
                        value = next(it)
                    except StopIteration:
                        result.append(fillvalue)
                        return result
                    result.append(value)

        def fn(x, y):
            torch.cos(torch.randn(4))
            return tuple(zip_longest(x, y))

        x = [1, 2, 3, 4]
        y = [10, 11, 12]

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        ref = fn(x, y)
        res = opt_fn(x, y)
        self.assertEqual(ref, res)

    def test_nn_reraise(self):
        # A user exception raised inside a hooked module's forward must
        # surface to the caller, and the compile failure must be recorded
        # in compilation metrics.
        class M(torch.nn.Module):
            def forward(self, x):
                raise ValueError("woof")
                return x + 2

        m = M()
        m.register_forward_pre_hook(lambda m, go: None)

        torch._dynamo.utils.clear_compilation_metrics()
        opt_call = torch.compile(lambda x: m(x), backend="eager")
        self.assertRaises(ValueError, lambda: opt_call(torch.randn(3)))
        metrics = torch._dynamo.utils.get_compilation_metrics()
        self.assertEqual(metrics[0].fail_reason, "Observed exception")

    def test_key_error(self):
        # KeyError from a dict lookup handled inside the compiled region.
        def fn(x, d):
            try:
                a = d["b"]
            except KeyError:
                a = 2
            return x * a

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        d = {"a": 1}
        ref = fn(x, d)
        res = opt_fn(x, d)
        self.assertEqual(ref, res)

    # NOTE(review): "atrribute" is a typo for "attribute"; name kept so
    # existing test-selection filters keep matching.
    def test_atrribute_error(self):
        # AttributeError from a plain attribute miss (no custom
        # __getattr__), handled inside the compiled region.
        class Mock:
            def __init__(self):
                self.a = 1

        mock = Mock()

        def fn(x):
            try:
                c = 2
                mock.b
            except AttributeError:
                c = 3
            return torch.sin(x) * c

        opt_fn = torch.compile(fn, backend="eager")
        x = torch.randn(4)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_raise_from_None(self):
        # `raise ... from None` suppresses chaining; the re-raised KeyError
        # must still be catchable by the caller's handler.
        # Inspired from os.environ
        class MyMapping:
            def __init__(self, d):
                self._d = d

            def __getitem__(self, key):
                try:
                    value = self._d[key]
                except KeyError:
                    raise KeyError(key) from None
                return value

        d = MyMapping({"a": 10, "b": 20})

        def mapping_get(obj, key, value=None):
            try:
                return obj.__getitem__(key)
            except KeyError:
                return value

        def fn(x, d, key):
            x = torch.sin(x + 1)
            return x, mapping_get(d, key)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

        x = torch.rand(2, 3)
        ref = fn(x, d, "m")
        res = opt_fn(x, d, "m")
        self.assertEqual(ref[0], res[0])
        self.assertEqual(ref[1], res[1])
404
405
# Allow running this file directly; delegate to Dynamo's shared test runner.
if __name__ == "__main__":
    import torch._dynamo.test_case

    torch._dynamo.test_case.run_tests()