xref: /aosp_15_r20/external/pytorch/test/test_prims.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: decompositions"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerfrom functools import partial
4*da0073e9SAndroid Build Coastguard Workerfrom itertools import product
5*da0073e9SAndroid Build Coastguard Workerimport unittest
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Workerimport torch
8*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import make_tensor
9*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (parametrize, run_tests, TestCase, TEST_SCIPY,
10*da0073e9SAndroid Build Coastguard Worker                                                  set_default_dtype)
11*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import (
12*da0073e9SAndroid Build Coastguard Worker    instantiate_device_type_tests,
13*da0073e9SAndroid Build Coastguard Worker    onlyCUDA,
14*da0073e9SAndroid Build Coastguard Worker    dtypes,
15*da0073e9SAndroid Build Coastguard Worker    OpDTypes,
16*da0073e9SAndroid Build Coastguard Worker)
17*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_methods_invocations import (
18*da0073e9SAndroid Build Coastguard Worker    op_db,
19*da0073e9SAndroid Build Coastguard Worker)
20*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import (
21*da0073e9SAndroid Build Coastguard Worker    ops,
22*da0073e9SAndroid Build Coastguard Worker)
23*da0073e9SAndroid Build Coastguard Worker
24*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.logging_tensor import LoggingTensor, capture_logs, log_input
25*da0073e9SAndroid Build Coastguard Workerimport torch._prims as prims
26*da0073e9SAndroid Build Coastguard Workerfrom torch._prims_common import CUDARngStateHelper
27*da0073e9SAndroid Build Coastguard Workerfrom torch._prims.executor import make_traced
28*da0073e9SAndroid Build Coastguard Workerimport torch._refs as refs
29*da0073e9SAndroid Build Coastguard Worker
30*da0073e9SAndroid Build Coastguard Worker
# SciPy is an optional dependency; tests that need it are skipped when absent
# (see the @unittest.skipIf guards below).
if TEST_SCIPY:
    import scipy.special

# Message fragments used when matching executor fallback warnings/errors.
# NOTE(review): neither constant is referenced in the visible portion of this
# file -- presumably consumed by tests elsewhere; confirm before removing.
NVPRIM_ATEN_FALLBACK_WARNING = "fallback to aten executor"
GET_ISOLATED_GRAPHMODULE_ERROR = "get_isolated_graphmodule failed on decomposition"
36*da0073e9SAndroid Build Coastguard Worker
37*da0073e9SAndroid Build Coastguard Workerclass TestPrims(TestCase):
38*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
39*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
40*da0073e9SAndroid Build Coastguard Worker    def test_broadcast_in_dim(self, device, dtype):
41*da0073e9SAndroid Build Coastguard Worker        def _wrapper(a, b, broadcast_dimensions):
42*da0073e9SAndroid Build Coastguard Worker            return prims.broadcast_in_dim(a, b.shape, broadcast_dimensions)
43*da0073e9SAndroid Build Coastguard Worker
44*da0073e9SAndroid Build Coastguard Worker        traced = make_traced(_wrapper)
45*da0073e9SAndroid Build Coastguard Worker        make_arg = partial(make_tensor, device=device, dtype=dtype)
46*da0073e9SAndroid Build Coastguard Worker
47*da0073e9SAndroid Build Coastguard Worker        for executor in ('aten',):
48*da0073e9SAndroid Build Coastguard Worker            fn = partial(traced, executor=executor)
49*da0073e9SAndroid Build Coastguard Worker            # Same shape
50*da0073e9SAndroid Build Coastguard Worker            shape = (5, 5)
51*da0073e9SAndroid Build Coastguard Worker            a = make_arg(shape)
52*da0073e9SAndroid Build Coastguard Worker            b = make_arg(shape, low=0.0, high=0.0)
53*da0073e9SAndroid Build Coastguard Worker            result = fn(a, b, (0, 1))
54*da0073e9SAndroid Build Coastguard Worker
55*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.shape, a.shape)
56*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(result.is_contiguous)
57*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a, result)
58*da0073e9SAndroid Build Coastguard Worker
59*da0073e9SAndroid Build Coastguard Worker            # Error input: reordering dims
60*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(Exception):
61*da0073e9SAndroid Build Coastguard Worker                result = fn(a, b, (1, 0))
62*da0073e9SAndroid Build Coastguard Worker
63*da0073e9SAndroid Build Coastguard Worker            # Adding outermost dimensions
64*da0073e9SAndroid Build Coastguard Worker            a = make_arg((5, 5))
65*da0073e9SAndroid Build Coastguard Worker            b = make_arg((3, 3, 5, 5), low=0.0, high=0.0)
66*da0073e9SAndroid Build Coastguard Worker            result = fn(a, b, (2, 3))
67*da0073e9SAndroid Build Coastguard Worker
68*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.shape, b.shape)
69*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a.broadcast_to(b.shape), result)
70*da0073e9SAndroid Build Coastguard Worker
71*da0073e9SAndroid Build Coastguard Worker            # Expands
72*da0073e9SAndroid Build Coastguard Worker            a = make_arg((1, 5, 1))
73*da0073e9SAndroid Build Coastguard Worker            b = make_arg((3, 5, 7), low=0.0, high=0.0)
74*da0073e9SAndroid Build Coastguard Worker            result = fn(a, b, (0, 1, 2))
75*da0073e9SAndroid Build Coastguard Worker
76*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.shape, b.shape)
77*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a.expand_as(result), result)
78*da0073e9SAndroid Build Coastguard Worker
79*da0073e9SAndroid Build Coastguard Worker            # Unsqueezes
80*da0073e9SAndroid Build Coastguard Worker            a = make_arg((1, 2, 3))
81*da0073e9SAndroid Build Coastguard Worker            b = make_arg((1, 2, 1, 3), low=0.0, high=0.0)
82*da0073e9SAndroid Build Coastguard Worker            result = fn(a, b, (0, 1, 3))
83*da0073e9SAndroid Build Coastguard Worker
84*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.shape, b.shape)
85*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a.unsqueeze(2), result)
86*da0073e9SAndroid Build Coastguard Worker
87*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
88*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
89*da0073e9SAndroid Build Coastguard Worker    def test_broadcast_in_dim_sum(self, device, dtype):
90*da0073e9SAndroid Build Coastguard Worker        def _wrapper(a):
91*da0073e9SAndroid Build Coastguard Worker            a_sum = prims.sum(a, [0, 1])
92*da0073e9SAndroid Build Coastguard Worker            a_bc = prims.broadcast_in_dim(a_sum, [], [])
93*da0073e9SAndroid Build Coastguard Worker            return a_bc
94*da0073e9SAndroid Build Coastguard Worker
95*da0073e9SAndroid Build Coastguard Worker        traced = make_traced(_wrapper)
96*da0073e9SAndroid Build Coastguard Worker        make_arg = partial(make_tensor, device=device, dtype=dtype)
97*da0073e9SAndroid Build Coastguard Worker
98*da0073e9SAndroid Build Coastguard Worker        for executor in ('aten',):
99*da0073e9SAndroid Build Coastguard Worker            fn = partial(traced, executor=executor)
100*da0073e9SAndroid Build Coastguard Worker            shape = (5, 5)
101*da0073e9SAndroid Build Coastguard Worker            a = make_arg(shape)
102*da0073e9SAndroid Build Coastguard Worker            result = fn(a)
103*da0073e9SAndroid Build Coastguard Worker
104*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.shape, ())
105*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(result.is_contiguous)
106*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(_wrapper(a), result)
107*da0073e9SAndroid Build Coastguard Worker
108*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
109*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float64, torch.long)
110*da0073e9SAndroid Build Coastguard Worker    def test_cbrt_prim(self, device, dtype):
111*da0073e9SAndroid Build Coastguard Worker        make_arg = partial(make_tensor, device=device, dtype=dtype)
112*da0073e9SAndroid Build Coastguard Worker        batches = [(), (1,), (2,), (0, 1), (1, 1), (2, 2)]
113*da0073e9SAndroid Build Coastguard Worker        shapes = [(), (0,), (1,), (5,)]
114*da0073e9SAndroid Build Coastguard Worker
115*da0073e9SAndroid Build Coastguard Worker        # Sets the default dtype to NumPy's default dtype of double
116*da0073e9SAndroid Build Coastguard Worker        with set_default_dtype(torch.double):
117*da0073e9SAndroid Build Coastguard Worker            # Tested here, as this OP is not currently exposed or tested in ATen
118*da0073e9SAndroid Build Coastguard Worker            for b, s in product(batches, shapes):
119*da0073e9SAndroid Build Coastguard Worker                x = make_arg(b + s)
120*da0073e9SAndroid Build Coastguard Worker                y = prims.cbrt(x)
121*da0073e9SAndroid Build Coastguard Worker
122*da0073e9SAndroid Build Coastguard Worker                x_np = x.cpu().numpy()
123*da0073e9SAndroid Build Coastguard Worker                y_np = scipy.special.cbrt(x_np)
124*da0073e9SAndroid Build Coastguard Worker
125*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(y, y_np, exact_device=False)
126*da0073e9SAndroid Build Coastguard Worker
127*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
128*da0073e9SAndroid Build Coastguard Worker    def test_collapse(self, device, dtype):
129*da0073e9SAndroid Build Coastguard Worker        t = torch.rand(2, 2, 2)
130*da0073e9SAndroid Build Coastguard Worker        dim_ranges = [(0, 0), (0, 1), (1, 2), (0, 2)]
131*da0073e9SAndroid Build Coastguard Worker        expected_shapes = [(2, 2, 2), (4, 2), (2, 4), (8,)]
132*da0073e9SAndroid Build Coastguard Worker
133*da0073e9SAndroid Build Coastguard Worker        for (start, end), shape in zip(dim_ranges, expected_shapes):
134*da0073e9SAndroid Build Coastguard Worker            expect = t.reshape(shape)
135*da0073e9SAndroid Build Coastguard Worker
136*da0073e9SAndroid Build Coastguard Worker            copy = prims.collapse(t, start, end)
137*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(copy, expect)
138*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(copy._is_view())
139*da0073e9SAndroid Build Coastguard Worker
140*da0073e9SAndroid Build Coastguard Worker            view = prims.collapse_view(t, start, end)
141*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(view, expect)
142*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(view._is_view())
143*da0073e9SAndroid Build Coastguard Worker
144*da0073e9SAndroid Build Coastguard Worker        t_discontig = t.transpose(0, 1)
145*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError, msg="no such view exists"):
146*da0073e9SAndroid Build Coastguard Worker            view = prims.collapse_view(t_discontig, 0, 2)
147*da0073e9SAndroid Build Coastguard Worker
148*da0073e9SAndroid Build Coastguard Worker        copy = prims.collapse(t_discontig, 0, 1)
149*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(copy, t_discontig.reshape(4, 2))
150*da0073e9SAndroid Build Coastguard Worker
151*da0073e9SAndroid Build Coastguard Worker        error_dims = [(-1, 1), (0, 3), (1, -1)]
152*da0073e9SAndroid Build Coastguard Worker        for start, end in error_dims:
153*da0073e9SAndroid Build Coastguard Worker            for fn in [prims.collapse, prims.collapse_view]:
154*da0073e9SAndroid Build Coastguard Worker                with self.assertRaises(AssertionError):
155*da0073e9SAndroid Build Coastguard Worker                    fn(t, start, end)
156*da0073e9SAndroid Build Coastguard Worker
157*da0073e9SAndroid Build Coastguard Worker
158*da0073e9SAndroid Build Coastguard Worker    def test_aten_overload_to_prims(self, device):
159*da0073e9SAndroid Build Coastguard Worker        # This test is to ensure that the torch.ops.aten calls are replaced with refs
160*da0073e9SAndroid Build Coastguard Worker        from torch.fx.experimental.proxy_tensor import make_fx
161*da0073e9SAndroid Build Coastguard Worker        from torch._prims.context import TorchRefsMode
162*da0073e9SAndroid Build Coastguard Worker
163*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(3, 3, device=device)
164*da0073e9SAndroid Build Coastguard Worker
165*da0073e9SAndroid Build Coastguard Worker        def func(a):
166*da0073e9SAndroid Build Coastguard Worker            return torch.ops.aten.sigmoid.default(torch.ops.aten.digamma.default(a))
167*da0073e9SAndroid Build Coastguard Worker
168*da0073e9SAndroid Build Coastguard Worker        with TorchRefsMode():
169*da0073e9SAndroid Build Coastguard Worker            gm = make_fx(func)(a)
170*da0073e9SAndroid Build Coastguard Worker
171*da0073e9SAndroid Build Coastguard Worker        # Check that all call_function nodes are prims
172*da0073e9SAndroid Build Coastguard Worker        call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
173*da0073e9SAndroid Build Coastguard Worker        all_prims_namespace = all(
174*da0073e9SAndroid Build Coastguard Worker            node.target.name().startswith("prims") for node in call_function_nodes
175*da0073e9SAndroid Build Coastguard Worker        )
176*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(all_prims_namespace)
177*da0073e9SAndroid Build Coastguard Worker
178*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
179*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
180*da0073e9SAndroid Build Coastguard Worker    @parametrize("correction", [0, 1])
181*da0073e9SAndroid Build Coastguard Worker    def test_var(self, device, dtype, correction):
182*da0073e9SAndroid Build Coastguard Worker        def _wrapper(a):
183*da0073e9SAndroid Build Coastguard Worker            return prims.var(a, [0, 1], correction=correction)
184*da0073e9SAndroid Build Coastguard Worker
185*da0073e9SAndroid Build Coastguard Worker        traced = make_traced(_wrapper)
186*da0073e9SAndroid Build Coastguard Worker        make_arg = partial(make_tensor, device=device, dtype=dtype)
187*da0073e9SAndroid Build Coastguard Worker
188*da0073e9SAndroid Build Coastguard Worker        for executor in ('aten',):
189*da0073e9SAndroid Build Coastguard Worker            fn = partial(traced, executor=executor)
190*da0073e9SAndroid Build Coastguard Worker            shape = (5, 5)
191*da0073e9SAndroid Build Coastguard Worker            a = make_arg(shape)
192*da0073e9SAndroid Build Coastguard Worker            result = fn(a)
193*da0073e9SAndroid Build Coastguard Worker
194*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.shape, ())
195*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(result.is_contiguous)
196*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(_wrapper(a), result)
197*da0073e9SAndroid Build Coastguard Worker
198*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
199*da0073e9SAndroid Build Coastguard Worker    def test_memory_format_strides(self, device, dtype):
200*da0073e9SAndroid Build Coastguard Worker        shapes = (
201*da0073e9SAndroid Build Coastguard Worker            (),
202*da0073e9SAndroid Build Coastguard Worker            (0,),
203*da0073e9SAndroid Build Coastguard Worker            (1,),
204*da0073e9SAndroid Build Coastguard Worker            (5),
205*da0073e9SAndroid Build Coastguard Worker            (1, 0),
206*da0073e9SAndroid Build Coastguard Worker            (1, 1),
207*da0073e9SAndroid Build Coastguard Worker            (3, 7),
208*da0073e9SAndroid Build Coastguard Worker            (3, 0, 2),
209*da0073e9SAndroid Build Coastguard Worker            (1, 1, 2),
210*da0073e9SAndroid Build Coastguard Worker            (4, 1, 1),
211*da0073e9SAndroid Build Coastguard Worker            (7, 8, 9),
212*da0073e9SAndroid Build Coastguard Worker        )
213*da0073e9SAndroid Build Coastguard Worker
214*da0073e9SAndroid Build Coastguard Worker        channels_last_shapes = (
215*da0073e9SAndroid Build Coastguard Worker            (0, 0, 0, 0),
216*da0073e9SAndroid Build Coastguard Worker            (1, 0, 3, 0),
217*da0073e9SAndroid Build Coastguard Worker            (0, 2, 3, 5),
218*da0073e9SAndroid Build Coastguard Worker            (2, 2, 2, 0),
219*da0073e9SAndroid Build Coastguard Worker            (5, 4, 3, 2),
220*da0073e9SAndroid Build Coastguard Worker            (8, 8, 7, 2),
221*da0073e9SAndroid Build Coastguard Worker            (9, 1, 3, 1),
222*da0073e9SAndroid Build Coastguard Worker            (4, 5, 8, 7)
223*da0073e9SAndroid Build Coastguard Worker        )
224*da0073e9SAndroid Build Coastguard Worker
225*da0073e9SAndroid Build Coastguard Worker        channels_last_3d_shapes = (
226*da0073e9SAndroid Build Coastguard Worker            (0, 8, 7, 9, 2),
227*da0073e9SAndroid Build Coastguard Worker            (5, 0, 7, 9, 2),
228*da0073e9SAndroid Build Coastguard Worker            (5, 0, 7, 9, 0),
229*da0073e9SAndroid Build Coastguard Worker            (5, 8, 7, 9, 2),
230*da0073e9SAndroid Build Coastguard Worker            (5, 1, 7, 9, 2),
231*da0073e9SAndroid Build Coastguard Worker            (5, 1, 7, 9, 1),
232*da0073e9SAndroid Build Coastguard Worker        )
233*da0073e9SAndroid Build Coastguard Worker
234*da0073e9SAndroid Build Coastguard Worker        pairs = (
235*da0073e9SAndroid Build Coastguard Worker            (shapes, torch.contiguous_format),
236*da0073e9SAndroid Build Coastguard Worker            (channels_last_shapes, torch.contiguous_format),
237*da0073e9SAndroid Build Coastguard Worker            (channels_last_3d_shapes, torch.contiguous_format),
238*da0073e9SAndroid Build Coastguard Worker            (channels_last_shapes, torch.channels_last),
239*da0073e9SAndroid Build Coastguard Worker            (channels_last_3d_shapes, torch.channels_last_3d),
240*da0073e9SAndroid Build Coastguard Worker        )
241*da0073e9SAndroid Build Coastguard Worker
242*da0073e9SAndroid Build Coastguard Worker        for shapes, memory_format in pairs:
243*da0073e9SAndroid Build Coastguard Worker            for shape in shapes:
244*da0073e9SAndroid Build Coastguard Worker                # tests empty
245*da0073e9SAndroid Build Coastguard Worker                expected = torch.empty(shape, device=device, dtype=dtype, memory_format=memory_format)
246*da0073e9SAndroid Build Coastguard Worker                actual = refs.empty(shape, device=device, dtype=dtype, memory_format=memory_format)
247*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(expected.stride(), actual.stride())
248*da0073e9SAndroid Build Coastguard Worker
249*da0073e9SAndroid Build Coastguard Worker                # tests clone
250*da0073e9SAndroid Build Coastguard Worker                a = torch.testing.make_tensor(shape, device=device, dtype=dtype)
251*da0073e9SAndroid Build Coastguard Worker                expected = torch.clone(a, memory_format=memory_format)
252*da0073e9SAndroid Build Coastguard Worker                actual = torch.clone(a, memory_format=memory_format)
253*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(expected.stride(), actual.stride())
254*da0073e9SAndroid Build Coastguard Worker
255*da0073e9SAndroid Build Coastguard Worker                # tests contiguous
256*da0073e9SAndroid Build Coastguard Worker                a = torch.testing.make_tensor(shape, device=device, dtype=dtype, noncontiguous=True)
257*da0073e9SAndroid Build Coastguard Worker                expected = a.contiguous(memory_format=memory_format)
258*da0073e9SAndroid Build Coastguard Worker                actual = refs.contiguous(a, memory_format=memory_format)
259*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(expected.stride(), actual.stride())
260*da0073e9SAndroid Build Coastguard Worker
261*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
262*da0073e9SAndroid Build Coastguard Worker    def test_reshape_view_method(self, device, dtype):
263*da0073e9SAndroid Build Coastguard Worker        make_arg = partial(make_tensor, device=device, dtype=dtype)
264*da0073e9SAndroid Build Coastguard Worker        a = make_arg((5, 5))
265*da0073e9SAndroid Build Coastguard Worker        new_shape = 1, 5, 1, 5
266*da0073e9SAndroid Build Coastguard Worker        result_eager = a.reshape(*new_shape)
267*da0073e9SAndroid Build Coastguard Worker        result_refs = refs.reshape(a, *new_shape)
268*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result_eager, result_refs)
269*da0073e9SAndroid Build Coastguard Worker
270*da0073e9SAndroid Build Coastguard Worker        result_eager = a.view(*new_shape)
271*da0073e9SAndroid Build Coastguard Worker        result_refs = refs.view(a, *new_shape)
272*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result_eager, result_refs)
273*da0073e9SAndroid Build Coastguard Worker
274*da0073e9SAndroid Build Coastguard Worker
275*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
276*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
277*da0073e9SAndroid Build Coastguard Worker    def test_philox_rand(self, device, dtype):
278*da0073e9SAndroid Build Coastguard Worker        sizes = (1000, 1000000)  # offsets of 4 and 8
279*da0073e9SAndroid Build Coastguard Worker        repeats = 2  # Checks multiple rand calls results with multiple philox_rand calls
280*da0073e9SAndroid Build Coastguard Worker        for size in sizes:
281*da0073e9SAndroid Build Coastguard Worker            torch.cuda.manual_seed(123)
282*da0073e9SAndroid Build Coastguard Worker            references = []
283*da0073e9SAndroid Build Coastguard Worker            results = []
284*da0073e9SAndroid Build Coastguard Worker            rng_states = []
285*da0073e9SAndroid Build Coastguard Worker            for _ in range(repeats):
286*da0073e9SAndroid Build Coastguard Worker                rng_states.append(CUDARngStateHelper.get_torch_state_as_tuple())
287*da0073e9SAndroid Build Coastguard Worker                references.append(torch.rand(size, device=device, dtype=dtype))
288*da0073e9SAndroid Build Coastguard Worker
289*da0073e9SAndroid Build Coastguard Worker            torch.cuda.manual_seed(123)
290*da0073e9SAndroid Build Coastguard Worker            for idx in range(repeats):
291*da0073e9SAndroid Build Coastguard Worker                seed, offset = rng_states[idx]
292*da0073e9SAndroid Build Coastguard Worker                result, _ = torch.ops.rngprims.philox_rand((size,),
293*da0073e9SAndroid Build Coastguard Worker                                                           seed=seed,
294*da0073e9SAndroid Build Coastguard Worker                                                           offset=offset,
295*da0073e9SAndroid Build Coastguard Worker                                                           stride=None,
296*da0073e9SAndroid Build Coastguard Worker                                                           device=device,
297*da0073e9SAndroid Build Coastguard Worker                                                           dtype=dtype)
298*da0073e9SAndroid Build Coastguard Worker                results.append(result)
299*da0073e9SAndroid Build Coastguard Worker
300*da0073e9SAndroid Build Coastguard Worker            for a, b in zip(references, results):
301*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(a, b)
302*da0073e9SAndroid Build Coastguard Worker
303*da0073e9SAndroid Build Coastguard Worker
304*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
305*da0073e9SAndroid Build Coastguard Worker    def test_functional_rng_wrappers(self, device, dtype):
306*da0073e9SAndroid Build Coastguard Worker
307*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(123)
308*da0073e9SAndroid Build Coastguard Worker        ref1 = torch.rand(10, device=device, dtype=dtype)
309*da0073e9SAndroid Build Coastguard Worker        ref2 = torch.rand(10, device=device, dtype=dtype)
310*da0073e9SAndroid Build Coastguard Worker
311*da0073e9SAndroid Build Coastguard Worker
312*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(123)
313*da0073e9SAndroid Build Coastguard Worker        rng_state1, res1 = torch._prims.rng_prims.run_and_save_rng_state(torch.rand, 10, device=device, dtype=dtype)
314*da0073e9SAndroid Build Coastguard Worker        rng_state2, res2 = torch._prims.rng_prims.run_and_save_rng_state(torch.rand, 10, device=device, dtype=dtype)
315*da0073e9SAndroid Build Coastguard Worker
316*da0073e9SAndroid Build Coastguard Worker        res3 = torch._prims.rng_prims.run_with_rng_state(rng_state1, torch.rand, 10, device=device, dtype=dtype)
317*da0073e9SAndroid Build Coastguard Worker        res4 = torch._prims.rng_prims.run_with_rng_state(rng_state2, torch.rand, 10, device=device, dtype=dtype)
318*da0073e9SAndroid Build Coastguard Worker
319*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref1, res1)
320*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref2, res2)
321*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref1, res3)
322*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref2, res4)
323*da0073e9SAndroid Build Coastguard Worker
324*da0073e9SAndroid Build Coastguard Workerclass TestPrimsBasic(TestCase):
325*da0073e9SAndroid Build Coastguard Worker    def test_torch_ops(self):
326*da0073e9SAndroid Build Coastguard Worker        r = make_tensor((2,), device='cpu', dtype=torch.float)
327*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.ops.prims.sin(r), torch.sin(r))
328*da0073e9SAndroid Build Coastguard Worker
329*da0073e9SAndroid Build Coastguard Worker        r = LoggingTensor(r)
330*da0073e9SAndroid Build Coastguard Worker        with capture_logs() as logs:
331*da0073e9SAndroid Build Coastguard Worker            log_input("input", r)
332*da0073e9SAndroid Build Coastguard Worker            prims.sin(r)
333*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline('\n'.join(logs), """\
334*da0073e9SAndroid Build Coastguard Worker$0: f32[2] = input('input')
335*da0073e9SAndroid Build Coastguard Worker$1: f32[2] = torch._ops.prims.sin.default($0)""")
336*da0073e9SAndroid Build Coastguard Worker
337*da0073e9SAndroid Build Coastguard Worker    def test_mul_complex(self):
338*da0073e9SAndroid Build Coastguard Worker        prims.mul(torch.randn(2), 1 + 1j)
339*da0073e9SAndroid Build Coastguard Worker
340*da0073e9SAndroid Build Coastguard Worker    def test_clone_complex(self):
341*da0073e9SAndroid Build Coastguard Worker        with torch._dispatch.python.enable_python_dispatcher():
342*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(4, dtype=torch.complex64, device='meta').conj()
343*da0073e9SAndroid Build Coastguard Worker            out = x + 1
344*da0073e9SAndroid Build Coastguard Worker
345*da0073e9SAndroid Build Coastguard Worker    def test_check_deprecation_warning(self):
346*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsRegex(FutureWarning, 'will be removed in the future'):
347*da0073e9SAndroid Build Coastguard Worker            torch._prims_common.check(True, lambda: 'message')
348*da0073e9SAndroid Build Coastguard Worker
349*da0073e9SAndroid Build Coastguard Worker
# Generate per-device variants (e.g. TestPrimsCPU, TestPrimsCUDA) so the
# @onlyCUDA/@dtypes-decorated tests above run on each available device type.
instantiate_device_type_tests(TestPrims, globals())
351*da0073e9SAndroid Build Coastguard Worker
352*da0073e9SAndroid Build Coastguard Worker
353*da0073e9SAndroid Build Coastguard Workerclass TestRefs(TestCase):
354*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
355*da0073e9SAndroid Build Coastguard Worker    def test_constant_pad_nd_memory_format(self, device, dtype):
356*da0073e9SAndroid Build Coastguard Worker        # Test memory format is preserved in unambiguous cases
357*da0073e9SAndroid Build Coastguard Worker        for mf, ndim in (
358*da0073e9SAndroid Build Coastguard Worker                (torch.channels_last, 4),
359*da0073e9SAndroid Build Coastguard Worker                (torch.contiguous_format, 4),
360*da0073e9SAndroid Build Coastguard Worker                (torch.channels_last_3d, 5),
361*da0073e9SAndroid Build Coastguard Worker                (torch.contiguous_format, 5),
362*da0073e9SAndroid Build Coastguard Worker        ):
363*da0073e9SAndroid Build Coastguard Worker            a = torch.zeros([2] * ndim).to(memory_format=mf)
364*da0073e9SAndroid Build Coastguard Worker            res = refs.constant_pad_nd(a, pad=[1] * (2 * ndim))
365*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(res.is_contiguous(memory_format=mf))
366*da0073e9SAndroid Build Coastguard Worker
367*da0073e9SAndroid Build Coastguard Worker        # Ambiguous cases
368*da0073e9SAndroid Build Coastguard Worker
369*da0073e9SAndroid Build Coastguard Worker        # is_channels_last_ and is_contiguous_, results in channels_last output
370*da0073e9SAndroid Build Coastguard Worker        a = torch.empty_strided((2, 1, 2, 2), stride=(4, 1, 2, 1))
371*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(a.is_contiguous(memory_format=torch.channels_last))
372*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(a.is_contiguous())
373*da0073e9SAndroid Build Coastguard Worker        actual = refs.constant_pad_nd(a, pad=[1] * 8)
374*da0073e9SAndroid Build Coastguard Worker        expect = torch.constant_pad_nd(a, pad=[1] * 8)
375*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual.stride(), expect.stride())
376*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(actual.is_contiguous(memory_format=torch.channels_last))
377*da0073e9SAndroid Build Coastguard Worker
378*da0073e9SAndroid Build Coastguard Worker        # is_channels_last_contiguous_ but not is_channels_last_, results in
379*da0073e9SAndroid Build Coastguard Worker        # contiguous output
380*da0073e9SAndroid Build Coastguard Worker        a = torch.empty_strided((2, 1, 2, 2), stride=(4, 4, 2, 1))
381*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(a.is_contiguous(memory_format=torch.channels_last))
382*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(a.is_contiguous())
383*da0073e9SAndroid Build Coastguard Worker        actual = refs.constant_pad_nd(a, pad=[1] * 8)
384*da0073e9SAndroid Build Coastguard Worker        expect = torch.constant_pad_nd(a, pad=[1] * 8)
385*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual.stride(), expect.stride())
386*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(actual.is_contiguous())
387*da0073e9SAndroid Build Coastguard Worker
388*da0073e9SAndroid Build Coastguard Worker    def test_unbind(self):
389*da0073e9SAndroid Build Coastguard Worker        # If unbind returns empty tuple, it breaks some assumptions in some backward tests in test_ops.py.
390*da0073e9SAndroid Build Coastguard Worker        # So can't put this test into common_methods_invocations.py.
391*da0073e9SAndroid Build Coastguard Worker        a = torch.rand([3, 0, 4])
392*da0073e9SAndroid Build Coastguard Worker        actual = refs.unbind(a, 1)
393*da0073e9SAndroid Build Coastguard Worker        expect = torch.unbind(a, 1)
394*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual, expect)
395*da0073e9SAndroid Build Coastguard Worker
396*da0073e9SAndroid Build Coastguard Worker    def test_logspace_with_complex_input(self):
397*da0073e9SAndroid Build Coastguard Worker        actual = refs.logspace(2, 10 + 5j, steps=5)
398*da0073e9SAndroid Build Coastguard Worker        expect = torch.logspace(2, 10 + 5j, steps=5)
399*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual, expect)
400*da0073e9SAndroid Build Coastguard Worker
401*da0073e9SAndroid Build Coastguard Worker    def test_linspace_with_complex_input(self):
402*da0073e9SAndroid Build Coastguard Worker        actual = refs.linspace(2, 10 + 5j, steps=5)
403*da0073e9SAndroid Build Coastguard Worker        expect = torch.linspace(2, 10 + 5j, steps=5)
404*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual, expect)
405*da0073e9SAndroid Build Coastguard Worker
406*da0073e9SAndroid Build Coastguard Worker    # From https://github.com/pytorch/pytorch/issues/109558
407*da0073e9SAndroid Build Coastguard Worker    def test_infinite_loop_from_py_dispatcher(self):
408*da0073e9SAndroid Build Coastguard Worker        # enables prim decomps
409*da0073e9SAndroid Build Coastguard Worker        with torch._dispatch.python.enable_python_dispatcher():
410*da0073e9SAndroid Build Coastguard Worker            x = torch.ones(4)
411*da0073e9SAndroid Build Coastguard Worker            y = x.to(device="meta")
412*da0073e9SAndroid Build Coastguard Worker
413*da0073e9SAndroid Build Coastguard Worker    def test_inferred_tags(self):
414*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.ops.prims.normal.default.tags, (torch.Tag.nondeterministic_seeded, torch.Tag.pt2_compliant_tag))
415*da0073e9SAndroid Build Coastguard Worker
416*da0073e9SAndroid Build Coastguard Worker
417*da0073e9SAndroid Build Coastguard Worker
418*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestRefs, globals())
419*da0073e9SAndroid Build Coastguard Worker
420*da0073e9SAndroid Build Coastguard Worker
421*da0073e9SAndroid Build Coastguard Workerclass TestDecomp(TestCase):
422*da0073e9SAndroid Build Coastguard Worker    @ops([op for op in op_db if op.supports_varargs], dtypes=OpDTypes.any_one)
423*da0073e9SAndroid Build Coastguard Worker    def test_decomposition_method_vararg(self, device, dtype, op):
424*da0073e9SAndroid Build Coastguard Worker        # some ops have vararg variants for the methods. this tests it.
425*da0073e9SAndroid Build Coastguard Worker        # we don't have tests for varargs in OpInfo, so we need to
426*da0073e9SAndroid Build Coastguard Worker        # improvise this a bit.
427*da0073e9SAndroid Build Coastguard Worker        # The rule for general functions (the special cases being e.g. tensor
428*da0073e9SAndroid Build Coastguard Worker        # creation functions taking shapes) is that things can be vararg
429*da0073e9SAndroid Build Coastguard Worker        # if the method has only one argument of sequence type.
430*da0073e9SAndroid Build Coastguard Worker        # e.g. permute can be called on a 3d tensor t as t.permute(0, 2, 1)
431*da0073e9SAndroid Build Coastguard Worker        #      as well as t.permute([0, 2, 1])
432*da0073e9SAndroid Build Coastguard Worker        #      when the signature in native_functions.yaml
433*da0073e9SAndroid Build Coastguard Worker        #      shows arguments Tensor self, IntList dims
434*da0073e9SAndroid Build Coastguard Worker        # we might need to adjust things for the factory functions or
435*da0073e9SAndroid Build Coastguard Worker        # have them do their own test
436*da0073e9SAndroid Build Coastguard Worker        from torch.fx.experimental.proxy_tensor import make_fx
437*da0073e9SAndroid Build Coastguard Worker        from torch._prims.context import TorchRefsMode
438*da0073e9SAndroid Build Coastguard Worker
439*da0073e9SAndroid Build Coastguard Worker        # filter out empty tuple as that cannot be the varargs
440*da0073e9SAndroid Build Coastguard Worker        sample_inputs = (si for si in op.sample_inputs(device, dtype, requires_grad=False)
441*da0073e9SAndroid Build Coastguard Worker                         if (si.args[-1] if si.args else si.input))
442*da0073e9SAndroid Build Coastguard Worker
443*da0073e9SAndroid Build Coastguard Worker        # just run one test, we assume there is a suitable one in the tests
444*da0073e9SAndroid Build Coastguard Worker        sample_input = next(sample_inputs)
445*da0073e9SAndroid Build Coastguard Worker        all_args = (sample_input.input,) + sample_input.args
446*da0073e9SAndroid Build Coastguard Worker
447*da0073e9SAndroid Build Coastguard Worker        # in general, the methods take varargs and not (always?) the function
448*da0073e9SAndroid Build Coastguard Worker        # variants, the exception to this rule are the factory functions
449*da0073e9SAndroid Build Coastguard Worker        if op.is_factory_function:
450*da0073e9SAndroid Build Coastguard Worker            fn = op.op
451*da0073e9SAndroid Build Coastguard Worker        else:
452*da0073e9SAndroid Build Coastguard Worker            fn = op.method_variant
453*da0073e9SAndroid Build Coastguard Worker        with TorchRefsMode():
454*da0073e9SAndroid Build Coastguard Worker            gm = make_fx(fn)(*all_args[:-1], *all_args[-1])
455*da0073e9SAndroid Build Coastguard Worker
456*da0073e9SAndroid Build Coastguard Worker        # in case we add random factory functions
457*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(1)
458*da0073e9SAndroid Build Coastguard Worker        res = gm(*all_args[:-1], *all_args[-1])
459*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(1)
460*da0073e9SAndroid Build Coastguard Worker        expected = fn(*all_args[:-1], *all_args[-1])
461*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, expected)
462*da0073e9SAndroid Build Coastguard Worker
463*da0073e9SAndroid Build Coastguard Worker
464*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestDecomp, globals())
465*da0073e9SAndroid Build Coastguard Worker
466*da0073e9SAndroid Build Coastguard Worker
467*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__":
468*da0073e9SAndroid Build Coastguard Worker    run_tests()
469