# Owner(s): ["module: decompositions"]

from functools import partial
from itertools import product
import unittest

import torch
from torch.testing import make_tensor
from torch.testing._internal.common_utils import (parametrize, run_tests, TestCase, TEST_SCIPY,
                                                  set_default_dtype)
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCUDA,
    dtypes,
    ops,
    OpDTypes,
)
from torch.testing._internal.common_methods_invocations import (
    op_db,
)

from torch.testing._internal.logging_tensor import LoggingTensor, capture_logs, log_input
import torch._prims as prims
from torch._prims_common import CUDARngStateHelper
from torch._prims.executor import make_traced
import torch._refs as refs


if TEST_SCIPY:
    import scipy.special

NVPRIM_ATEN_FALLBACK_WARNING = "fallback to aten executor"
GET_ISOLATED_GRAPHMODULE_ERROR = "get_isolated_graphmodule failed on decomposition"


class TestPrims(TestCase):
    @onlyCUDA
    @dtypes(torch.float32)
    def test_broadcast_in_dim(self, device, dtype):
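        # broadcast_in_dim maps dimension i of the input to dimension
        # broadcast_dimensions[i] of the target shape; size-1 input dims are
        # expanded and unmapped output dims are added as new broadcast
        # dimensions, as the cases below exercise.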
        def _wrapper(a, b, broadcast_dimensions):
            return prims.broadcast_in_dim(a, b.shape, broadcast_dimensions)

        traced = make_traced(_wrapper)
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for executor in ('aten',):
            fn = partial(traced, executor=executor)
            # Same shape
            shape = (5, 5)
            a = make_arg(shape)
            b = make_arg(shape, low=0.0, high=0.0)
            result = fn(a, b, (0, 1))

            self.assertEqual(result.shape, a.shape)
            self.assertTrue(result.is_contiguous())
            self.assertEqual(a, result)

            # Error input: reordering dims
            with self.assertRaises(Exception):
                result = fn(a, b, (1, 0))

            # Adding outermost dimensions
            a = make_arg((5, 5))
            b = make_arg((3, 3, 5, 5), low=0.0, high=0.0)
            result = fn(a, b, (2, 3))

            self.assertEqual(result.shape, b.shape)
            self.assertEqual(a.broadcast_to(b.shape), result)

            # Expands
            a = make_arg((1, 5, 1))
            b = make_arg((3, 5, 7), low=0.0, high=0.0)
            result = fn(a, b, (0, 1, 2))

            self.assertEqual(result.shape, b.shape)
            self.assertEqual(a.expand_as(result), result)

            # Unsqueezes
            a = make_arg((1, 2, 3))
            b = make_arg((1, 2, 1, 3), low=0.0, high=0.0)
            result = fn(a, b, (0, 1, 3))

            self.assertEqual(result.shape, b.shape)
            self.assertEqual(a.unsqueeze(2), result)

    @onlyCUDA
    @dtypes(torch.float32)
    def test_broadcast_in_dim_sum(self, device, dtype):
        def _wrapper(a):
            a_sum = prims.sum(a, [0, 1])
            a_bc = prims.broadcast_in_dim(a_sum, [], [])
            return a_bc

        traced = make_traced(_wrapper)
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for executor in ('aten',):
            fn = partial(traced, executor=executor)
            shape = (5, 5)
            a = make_arg(shape)
            result = fn(a)

            self.assertEqual(result.shape, ())
            self.assertTrue(result.is_contiguous())
            self.assertEqual(_wrapper(a), result)

    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    @dtypes(torch.float64, torch.long)
    def test_cbrt_prim(self, device, dtype):
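        # prims.cbrt is the real-valued cube root (defined for negative inputs,
        # unlike pow(x, 1/3) in floating point); scipy.special.cbrt serves as
        # the reference implementation.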
        make_arg = partial(make_tensor, device=device, dtype=dtype)
        batches = [(), (1,), (2,), (0, 1), (1, 1), (2, 2)]
        shapes = [(), (0,), (1,), (5,)]

        # Sets the default dtype to NumPy's default dtype of double
        with set_default_dtype(torch.double):
            # Tested here, as this op is not currently exposed or tested in ATen
            for b, s in product(batches, shapes):
                x = make_arg(b + s)
                y = prims.cbrt(x)

                x_np = x.cpu().numpy()
                y_np = scipy.special.cbrt(x_np)

                self.assertEqual(y, y_np, exact_device=False)

    @dtypes(torch.float32)
    def test_collapse(self, device, dtype):
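        # prims.collapse flattens the (inclusive) dimension range [start, end]
        # into a single dimension and returns a copy; prims.collapse_view does
        # the same but returns a view, raising when no such view exists
        # (e.g. for non-contiguous inputs).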
        t = torch.rand(2, 2, 2)
        dim_ranges = [(0, 0), (0, 1), (1, 2), (0, 2)]
        expected_shapes = [(2, 2, 2), (4, 2), (2, 4), (8,)]

        for (start, end), shape in zip(dim_ranges, expected_shapes):
            expect = t.reshape(shape)

            copy = prims.collapse(t, start, end)
            self.assertEqual(copy, expect)
            self.assertFalse(copy._is_view())

            view = prims.collapse_view(t, start, end)
            self.assertEqual(view, expect)
            self.assertTrue(view._is_view())

        t_discontig = t.transpose(0, 1)
        with self.assertRaises(ValueError, msg="no such view exists"):
            view = prims.collapse_view(t_discontig, 0, 2)

        copy = prims.collapse(t_discontig, 0, 1)
        self.assertEqual(copy, t_discontig.reshape(4, 2))

        error_dims = [(-1, 1), (0, 3), (1, -1)]
        for start, end in error_dims:
            for fn in [prims.collapse, prims.collapse_view]:
                with self.assertRaises(AssertionError):
                    fn(t, start, end)

    def test_aten_overload_to_prims(self, device):
        # Ensure that torch.ops.aten OpOverload calls are replaced with refs under
        # TorchRefsMode, so the traced graph contains only prims
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsMode

        a = torch.randn(3, 3, device=device)

        def func(a):
            return torch.ops.aten.sigmoid.default(torch.ops.aten.digamma.default(a))

        with TorchRefsMode():
            gm = make_fx(func)(a)

        # Check that all call_function nodes are prims
        call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
        all_prims_namespace = all(
            node.target.name().startswith("prims") for node in call_function_nodes
        )
        self.assertTrue(all_prims_namespace)

    @onlyCUDA
    @dtypes(torch.float32)
    @parametrize("correction", [0, 1])
    def test_var(self, device, dtype, correction):
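        # prims.var reduces over the given dims; `correction` is the Bessel-style
        # degrees-of-freedom adjustment (0 for population variance, 1 for sample
        # variance), matching torch.var's `correction` argument.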
        def _wrapper(a):
            return prims.var(a, [0, 1], correction=correction)

        traced = make_traced(_wrapper)
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for executor in ('aten',):
            fn = partial(traced, executor=executor)
            shape = (5, 5)
            a = make_arg(shape)
            result = fn(a)

            self.assertEqual(result.shape, ())
            self.assertTrue(result.is_contiguous())
            self.assertEqual(_wrapper(a), result)

    @dtypes(torch.float32)
    def test_memory_format_strides(self, device, dtype):
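        # For every (shape, memory_format) pair below, the refs (empty, clone,
        # contiguous) should produce the same strides as their eager counterparts.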
        shapes = (
            (),
            (0,),
            (1,),
            (5,),
            (1, 0),
            (1, 1),
            (3, 7),
            (3, 0, 2),
            (1, 1, 2),
            (4, 1, 1),
            (7, 8, 9),
        )

        channels_last_shapes = (
            (0, 0, 0, 0),
            (1, 0, 3, 0),
            (0, 2, 3, 5),
            (2, 2, 2, 0),
            (5, 4, 3, 2),
            (8, 8, 7, 2),
            (9, 1, 3, 1),
            (4, 5, 8, 7)
        )

        channels_last_3d_shapes = (
            (0, 8, 7, 9, 2),
            (5, 0, 7, 9, 2),
            (5, 0, 7, 9, 0),
            (5, 8, 7, 9, 2),
            (5, 1, 7, 9, 2),
            (5, 1, 7, 9, 1),
        )

        pairs = (
            (shapes, torch.contiguous_format),
            (channels_last_shapes, torch.contiguous_format),
            (channels_last_3d_shapes, torch.contiguous_format),
            (channels_last_shapes, torch.channels_last),
            (channels_last_3d_shapes, torch.channels_last_3d),
        )

        for shapes, memory_format in pairs:
            for shape in shapes:
                # tests empty
                expected = torch.empty(shape, device=device, dtype=dtype, memory_format=memory_format)
                actual = refs.empty(shape, device=device, dtype=dtype, memory_format=memory_format)
                self.assertEqual(expected.stride(), actual.stride())

                # tests clone
                a = torch.testing.make_tensor(shape, device=device, dtype=dtype)
                expected = torch.clone(a, memory_format=memory_format)
                actual = refs.clone(a, memory_format=memory_format)
                self.assertEqual(expected.stride(), actual.stride())

                # tests contiguous
                a = torch.testing.make_tensor(shape, device=device, dtype=dtype, noncontiguous=True)
                expected = a.contiguous(memory_format=memory_format)
                actual = refs.contiguous(a, memory_format=memory_format)
                self.assertEqual(expected.stride(), actual.stride())

    @dtypes(torch.float32)
    def test_reshape_view_method(self, device, dtype):
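        # The refs should accept the target shape as varargs, matching the
        # Tensor.reshape / Tensor.view method calling convention.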
        make_arg = partial(make_tensor, device=device, dtype=dtype)
        a = make_arg((5, 5))
        new_shape = 1, 5, 1, 5
        result_eager = a.reshape(*new_shape)
        result_refs = refs.reshape(a, *new_shape)
        self.assertEqual(result_eager, result_refs)

        result_eager = a.view(*new_shape)
        result_refs = refs.view(a, *new_shape)
        self.assertEqual(result_eager, result_refs)

    @onlyCUDA
    @dtypes(torch.float32)
    def test_philox_rand(self, device, dtype):
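        # philox_rand is a functional (stateless) RNG prim: given the (seed, offset)
        # pair captured from the CUDA generator, it should reproduce the values that
        # the stateful torch.rand calls produced.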
        sizes = (1000, 1000000)  # offsets of 4 and 8
        repeats = 2  # Checks that multiple rand calls match multiple philox_rand calls
        for size in sizes:
            torch.cuda.manual_seed(123)
            references = []
            results = []
            rng_states = []
            for _ in range(repeats):
                rng_states.append(CUDARngStateHelper.get_torch_state_as_tuple())
                references.append(torch.rand(size, device=device, dtype=dtype))

            torch.cuda.manual_seed(123)
            for idx in range(repeats):
                seed, offset = rng_states[idx]
                result, _ = torch.ops.rngprims.philox_rand((size,),
                                                           seed=seed,
                                                           offset=offset,
                                                           stride=None,
                                                           device=device,
                                                           dtype=dtype)
                results.append(result)

            for a, b in zip(references, results):
                self.assertEqual(a, b)

    @dtypes(torch.float32)
    def test_functional_rng_wrappers(self, device, dtype):
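        # run_and_save_rng_state returns the RNG state captured before running the
        # callable along with its result; run_with_rng_state replays the callable
        # under that saved state, so both should reproduce the eager torch.rand values.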
        torch.manual_seed(123)
        ref1 = torch.rand(10, device=device, dtype=dtype)
        ref2 = torch.rand(10, device=device, dtype=dtype)

        torch.manual_seed(123)
        rng_state1, res1 = torch._prims.rng_prims.run_and_save_rng_state(torch.rand, 10, device=device, dtype=dtype)
        rng_state2, res2 = torch._prims.rng_prims.run_and_save_rng_state(torch.rand, 10, device=device, dtype=dtype)

        res3 = torch._prims.rng_prims.run_with_rng_state(rng_state1, torch.rand, 10, device=device, dtype=dtype)
        res4 = torch._prims.rng_prims.run_with_rng_state(rng_state2, torch.rand, 10, device=device, dtype=dtype)

        self.assertEqual(ref1, res1)
        self.assertEqual(ref2, res2)
        self.assertEqual(ref1, res3)
        self.assertEqual(ref2, res4)


class TestPrimsBasic(TestCase):
    def test_torch_ops(self):
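        # LoggingTensor records ops as they are dispatched to it; calling prims.sin
        # on it should show up as a single prims.sin.default call in the log.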
        r = make_tensor((2,), device='cpu', dtype=torch.float)
        self.assertEqual(torch.ops.prims.sin(r), torch.sin(r))

        r = LoggingTensor(r)
        with capture_logs() as logs:
            log_input("input", r)
            prims.sin(r)
        self.assertExpectedInline('\n'.join(logs), """\
$0: f32[2] = input('input')
$1: f32[2] = torch._ops.prims.sin.default($0)""")

    def test_mul_complex(self):
        prims.mul(torch.randn(2), 1 + 1j)

    def test_clone_complex(self):
        with torch._dispatch.python.enable_python_dispatcher():
            x = torch.randn(4, dtype=torch.complex64, device='meta').conj()
            out = x + 1

    def test_check_deprecation_warning(self):
        with self.assertWarnsRegex(FutureWarning, 'will be removed in the future'):
            torch._prims_common.check(True, lambda: 'message')


instantiate_device_type_tests(TestPrims, globals())


class TestRefs(TestCase):
    @dtypes(torch.float32)
    def test_constant_pad_nd_memory_format(self, device, dtype):
        # Test that memory format is preserved in unambiguous cases
        for mf, ndim in (
                (torch.channels_last, 4),
                (torch.contiguous_format, 4),
                (torch.channels_last_3d, 5),
                (torch.contiguous_format, 5),
        ):
            a = torch.zeros([2] * ndim).to(memory_format=mf)
            res = refs.constant_pad_nd(a, pad=[1] * (2 * ndim))
            self.assertTrue(res.is_contiguous(memory_format=mf))

        # Ambiguous cases

        # is_channels_last_ and is_contiguous_, results in channels_last output
        a = torch.empty_strided((2, 1, 2, 2), stride=(4, 1, 2, 1))
        self.assertTrue(a.is_contiguous(memory_format=torch.channels_last))
        self.assertTrue(a.is_contiguous())
        actual = refs.constant_pad_nd(a, pad=[1] * 8)
        expect = torch.constant_pad_nd(a, pad=[1] * 8)
        self.assertEqual(actual.stride(), expect.stride())
        self.assertTrue(actual.is_contiguous(memory_format=torch.channels_last))

        # is_channels_last_contiguous_ but not is_channels_last_, results in
        # contiguous output
        a = torch.empty_strided((2, 1, 2, 2), stride=(4, 4, 2, 1))
        self.assertTrue(a.is_contiguous(memory_format=torch.channels_last))
        self.assertTrue(a.is_contiguous())
        actual = refs.constant_pad_nd(a, pad=[1] * 8)
        expect = torch.constant_pad_nd(a, pad=[1] * 8)
        self.assertEqual(actual.stride(), expect.stride())
        self.assertTrue(actual.is_contiguous())

    def test_unbind(self):
        # If unbind returns an empty tuple, it breaks some assumptions in backward
        # tests in test_ops.py, so this test can't live in common_methods_invocations.py.
        a = torch.rand([3, 0, 4])
        actual = refs.unbind(a, 1)
        expect = torch.unbind(a, 1)
        self.assertEqual(actual, expect)

    def test_logspace_with_complex_input(self):
        actual = refs.logspace(2, 10 + 5j, steps=5)
        expect = torch.logspace(2, 10 + 5j, steps=5)
        self.assertEqual(actual, expect)

    def test_linspace_with_complex_input(self):
        actual = refs.linspace(2, 10 + 5j, steps=5)
        expect = torch.linspace(2, 10 + 5j, steps=5)
        self.assertEqual(actual, expect)

    # From https://github.com/pytorch/pytorch/issues/109558
    def test_infinite_loop_from_py_dispatcher(self):
        # enables prim decomps
        with torch._dispatch.python.enable_python_dispatcher():
            x = torch.ones(4)
            y = x.to(device="meta")

    def test_inferred_tags(self):
        self.assertEqual(torch.ops.prims.normal.default.tags, (torch.Tag.nondeterministic_seeded, torch.Tag.pt2_compliant_tag))


instantiate_device_type_tests(TestRefs, globals())


class TestDecomp(TestCase):
    @ops([op for op in op_db if op.supports_varargs], dtypes=OpDTypes.any_one)
    def test_decomposition_method_vararg(self, device, dtype, op):
        # Some ops have vararg variants of their methods; this tests them.
        # We don't have tests for varargs in OpInfo, so we need to improvise
        # a bit here.
        # The rule for general functions (the special cases being e.g. tensor
        # creation functions taking shapes) is that a call can be vararg
        # if the method has only one argument of sequence type.
        # E.g. permute can be called on a 3d tensor t as t.permute(0, 2, 1)
        #      as well as t.permute([0, 2, 1]),
        #      since the signature in native_functions.yaml
        #      shows arguments Tensor self, IntList dims.
        # We might need to adjust things for the factory functions or
        # have them do their own test.
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsMode

        # filter out empty tuples as those cannot be the varargs
        sample_inputs = (si for si in op.sample_inputs(device, dtype, requires_grad=False)
                         if (si.args[-1] if si.args else si.input))

        # just run one test; we assume there is a suitable one among the samples
        sample_input = next(sample_inputs)
        all_args = (sample_input.input,) + sample_input.args

        # In general the method variants take varargs while the function
        # variants do not (always?); factory functions are the exception.
        if op.is_factory_function:
            fn = op.op
        else:
            fn = op.method_variant
        with TorchRefsMode():
            gm = make_fx(fn)(*all_args[:-1], *all_args[-1])

        # in case we add random factory functions
        torch.manual_seed(1)
        res = gm(*all_args[:-1], *all_args[-1])
        torch.manual_seed(1)
        expected = fn(*all_args[:-1], *all_args[-1])
        self.assertEqual(res, expected)


instantiate_device_type_tests(TestDecomp, globals())


if __name__ == "__main__":
    run_tests()