xref: /aosp_15_r20/external/pytorch/test/test_meta.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: decompositions"]
2
3import itertools
4import torch
5import os
6import numpy as np
7from enum import Enum
8from torch.overrides import resolve_name
9from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
10from torch.utils import _pytree as pytree
11from torch._subclasses.meta_utils import MetaConverter, assert_metadata_eq, is_sparse_any
12import torch.utils._python_dispatch
13from torch._dispatch.python import enable_python_dispatcher
14from torch._ops import OpOverload, OpOverloadPacket
15from torch.testing import make_tensor
16from torch.testing._internal.common_utils import unMarkDynamoStrictTest
17from torch.testing._internal.common_utils import (
18    TestCase,
19    skipIfCrossRef,
20    skipIfTorchDynamo,
21    suppress_warnings,
22    TEST_WITH_ASAN,
23    TEST_WITH_TORCHDYNAMO,
24    run_tests,
25    dtype_abbrs,
26    parametrize
27)
28from torch.testing._internal.common_device_type import (
29    ops,
30    instantiate_device_type_tests,
31    onlyCUDA,
32    onlyCPU,
33    OpDTypes,
34)
35from torch.testing._internal.common_methods_invocations import (
36    binary_ufuncs, op_db, foreach_unary_op_db, foreach_binary_op_db,
37    foreach_pointwise_op_db, foreach_reduce_op_db, foreach_other_op_db)
38from torch.testing._internal.opinfo.core import S, SampleInput
39from torchgen.yaml_utils import YamlLoader
40from torchgen.model import OperatorName
41
42import copy
43import sys
44import yaml
45import atexit
46import re
47from collections import defaultdict
48from collections.abc import Iterable
49import unittest
50import warnings
51import weakref
52from functools import partial, wraps
53
54bf16 = torch.bfloat16
55f64 = torch.float64
56f32 = torch.float32
57f16 = torch.float16
58c32 = torch.complex32
59c64 = torch.complex64
60c128 = torch.complex128
61i8 = torch.int8
62i16 = torch.int16
63i32 = torch.int32
64i64 = torch.int64
65b8 = torch.bool
66u8 = torch.uint8
67u16 = torch.uint16
68u32 = torch.uint32
69u64 = torch.uint64
70
# Every foreach OpInfo (unary, binary, pointwise, reduce, other), concatenated
# into one new list in that order.
foreach_op_db = [
    *foreach_unary_op_db,
    *foreach_binary_op_db,
    *foreach_pointwise_op_db,
    *foreach_reduce_op_db,
    *foreach_other_op_db,
]
78
79
class TestMetaConverter(TestCase):
    """Tests for ``MetaConverter``, which converts real tensors into meta tensors.

    Most tests build a real tensor in a specific configuration (views of
    leaves/non-leaves, requires_grad, channels_last, non-dense strides, dtype
    and complex views), convert it, and compare metadata with
    ``assertMetadataMatches``.  ``test_weakref`` and
    ``test_tensor_outlives_converter`` instead check the converter's memo
    tables and its own lifetime.
    """
    def assertSameVersionCounter(self, m1, m2):
        """Assert that ``m1`` and ``m2`` share a single version counter.

        The two versions must agree initially, and an in-place op on
        ``m1._base`` must change ``m1``'s version with ``m2``'s following it.
        ``m1`` must therefore be a view (it needs a ``_base``).
        """
        # Cannot easily test m1 and m2 have same storage due to
        # lack of Storage bindings.  Use version counter.
        vc = m1._version
        self.assertEqual(m2._version, vc)
        # Doing it this way ensures that we get VC bump even with leaves
        with torch.no_grad():
            m1._base.add_(3)
        self.assertNotEqual(m1._version, vc)
        self.assertEqual(m2._version, m1._version)

    def assertMetadataMatches(self, m1, m2):
        """Assert that ``m1`` and ``m2`` have equal metadata per ``assert_metadata_eq``."""
        assert_metadata_eq(self.assertEqual, m1, m2)

    # Two views of the same non-leaf must convert to distinct meta tensors
    # that match their sources and share one version counter.
    def test_view_of_non_leaf(self):
        x = torch.randn(4, requires_grad=True)
        y = x.neg()
        z1 = y[:]
        z2 = y[:]
        to_meta = MetaConverter()
        m1 = to_meta(z1)
        m2 = to_meta(z2)

        # check the test is actually testing what it claims
        self.assertTrue(m1._is_view())
        self.assertFalse(m1._base.is_leaf)

        self.assertIsNot(m1, m2)
        self.assertMetadataMatches(m1, z1)
        self.assertMetadataMatches(m2, z2)
        self.assertSameVersionCounter(m1, m2)

    # Like test_view_of_non_leaf, but the views are taken directly of a leaf
    # that requires grad.
    def test_view_of_leaf(self):
        x = torch.randn(4, requires_grad=True)
        z1 = x[:]
        z2 = x[:]
        to_meta = MetaConverter()
        m1 = to_meta(z1)
        m2 = to_meta(z2)

        # check the test is actually testing what it claims
        self.assertTrue(m1._is_view())
        self.assertTrue(m1._base.is_leaf)

        self.assertIsNot(m1, m2)
        self.assertMetadataMatches(m1, z1)
        self.assertMetadataMatches(m2, z2)
        self.assertSameVersionCounter(m1, m2)

    # z is a view of y, which is itself a view of x.  requires_grad is turned
    # on for the intermediate view y only, so z is a non-leaf.  The base x is
    # converted first, then the view z.
    def test_view_of_view_of_leaf(self):
        x = torch.randn(8)
        y = x.view(2, 4)
        y.requires_grad = True
        z = y.view(2, 2, 2)

        to_meta = MetaConverter()
        mx = to_meta(x)
        mz = to_meta(z)

        self.assertFalse(z.is_leaf)

        self.assertMetadataMatches(mx, x)
        self.assertMetadataMatches(mz, z)

    # A leaf that requires grad stays a leaf that requires grad.
    def test_leaf(self):
        x = torch.randn(4, requires_grad=True)
        to_meta = MetaConverter()
        m = to_meta(x)

        # check the test is actually testing what it claims
        self.assertTrue(m.is_leaf)
        self.assertTrue(m.requires_grad)

        self.assertMetadataMatches(m, x)

    # The result of an op on a requires_grad leaf stays a non-leaf that
    # requires grad.
    def test_non_leaf(self):
        x = torch.randn(4, requires_grad=True)
        y = x.neg()
        to_meta = MetaConverter()
        m = to_meta(y)

        # check the test is actually testing what it claims
        self.assertFalse(m.is_leaf)
        self.assertTrue(m.requires_grad)

        self.assertMetadataMatches(m, y)

    # requires_grad=False is preserved.
    def test_requires_grad_false(self):
        x = torch.randn(4, requires_grad=False)
        to_meta = MetaConverter()
        m = to_meta(x)

        # check the test is actually testing what it claims
        self.assertFalse(m.requires_grad)

        self.assertMetadataMatches(m, x)

    # channels_last layout on a tensor that does not require grad.
    def test_channels_last(self):
        x = torch.empty(2, 3, 4, 5, memory_format=torch.channels_last)
        to_meta = MetaConverter()
        m = to_meta(x)

        # check the test is actually testing what it claims
        self.assertTrue(m.is_leaf)

        self.assertMetadataMatches(m, x)

    # channels_last layout on a leaf that requires grad.
    def test_channels_last_leaf(self):
        x = torch.empty(2, 3, 4, 5, memory_format=torch.channels_last, requires_grad=True)
        to_meta = MetaConverter()
        m = to_meta(x)

        # check the test is actually testing what it claims
        self.assertTrue(m.requires_grad)
        self.assertTrue(m.is_leaf)

        self.assertMetadataMatches(m, x)

    # channels_last non-leaf.  Also checks that autograd can run with the
    # converted tensor as input.
    def test_channels_last_non_leaf(self):
        x = torch.empty(2, 3, 4, 5, memory_format=torch.channels_last, requires_grad=True)
        y = x + 2

        # sanity
        self.assertEqual(x.stride(), y.stride())
        self.assertFalse(y.is_leaf)

        to_meta = MetaConverter()
        m = to_meta(y)

        # check the test is actually testing what it claims
        self.assertTrue(m.requires_grad)
        self.assertFalse(m.is_leaf)

        self.assertMetadataMatches(m, y)

        # Check that we can autograd with m as input without erroring;
        # see https://github.com/pytorch/pytorch/issues/87956
        loss = m.sum()
        torch.autograd.grad(loss, m)

    # Leaf with non-dense strides: shape (2, 2) with strides (4, 2) leaves
    # gaps between the elements in storage.
    def test_empty_strided_non_dense_leaf(self):
        x = torch.empty_strided((2, 2), (4, 2), requires_grad=True)

        to_meta = MetaConverter()
        m = to_meta(x)

        # check the test is actually testing what it claims
        self.assertTrue(m.requires_grad)
        self.assertTrue(m.is_leaf)

        self.assertMetadataMatches(m, x)

    # Smoke test with no assertions.  An in-place add of a requires_grad
    # tensor into a view must work both on the real view and on its meta
    # conversion.  The test passes if nothing raises.
    def test_view_mutate(self):
        x = torch.zeros(4)
        y = x.view(2, 2)

        to_meta = MetaConverter()
        m = to_meta(y)

        y.add_(torch.randn(2, 2, requires_grad=True))
        m.add_(torch.randn(2, 2, device='meta', requires_grad=True))

    # Despite its name, x remains a leaf (asserted below).  Under no_grad,
    # set_ re-points x at its own storage with storage_offset 10, size (2,)
    # and stride (2,).
    def test_non_leaf_torture(self):
        x = torch.empty(20, requires_grad=True)
        with torch.no_grad():
            x.set_(x.storage(), 10, (2,), (2,))

        to_meta = MetaConverter()
        m = to_meta(x)

        # check the test is actually testing what it claims
        self.assertTrue(m.requires_grad)
        self.assertTrue(m.is_leaf)

        self.assertMetadataMatches(m, x)

    # NB: complex stuff is not actually exercised right now because
    # we have a blanket exclusion for complex conversion
    # NOTE(review): several tests below do pass complex tensors (or views of
    # them) through MetaConverter; confirm whether this note is still current.

    # Real-valued view of a complex tensor (view_as_real).
    def test_view_as_real(self):
        x = torch.randn(4, dtype=torch.complex64)
        y = torch.view_as_real(x)
        m = MetaConverter()(y)
        self.assertMetadataMatches(m, y)

    # Non-contiguous complex32 tensor produced by indexing out dimension 1.
    def test_complex_noncontiguous_bug(self):
        x = torch.randn((2, 2, 4, 9), dtype=torch.complex32)[:, 0, :, :]
        m = MetaConverter()(x)
        self.assertMetadataMatches(m, x)

    # Complex view of a real (4, 2) tensor (view_as_complex).
    def test_view_as_complex(self):
        x = torch.randn((4, 2), dtype=torch.float32)
        y = torch.view_as_complex(x)
        m = MetaConverter()(y)
        self.assertMetadataMatches(m, y)

    # Dtype-reinterpreting view: float32 data viewed as int32.
    def test_view_dtype(self):
        x = torch.randn(4, dtype=torch.float32)
        y = x.view(dtype=torch.int32)
        m = MetaConverter()(y)
        self.assertMetadataMatches(m, y)

    # The .imag view of a complex64 tensor.
    def test_imag(self):
        x = torch.randn(4, dtype=torch.complex64)
        y = x.imag
        m = MetaConverter()(y)
        self.assertMetadataMatches(m, y)

    # set_ of a 0-d tensor onto an existing storage must leave the storage's
    # size unchanged.  Note: the tensor named `meta` is created without a
    # device argument, and MetaConverter is not used in this test.
    def test_inplace_set_storage(self):
        x = torch.tensor([0, 1], dtype=torch.int64)
        storage = x.untyped_storage()
        ssize = storage.size()
        meta = torch.empty((), dtype=torch.int64)
        meta.set_(storage, 0, (), ())
        self.assertEqual(storage.size(), ssize)

    # Memoization and lifetime of MetaConverter's tables.  Converting the same
    # tensor twice returns the same object with one memo entry per tensor and
    # storage.  The describer's lookup entries go away with the real tensors,
    # while tensor_memo/storage_memo entries go away with the converted ones.
    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
    def test_weakref(self):
        x = torch.randn(4, 4, 4)
        m = MetaConverter()
        y = m(x)
        z = m(x)
        self.assertIs(y, z)
        self.assertEqual(len(m.tensor_memo), 1)
        self.assertEqual(len(m.storage_memo), 1)
        self.assertEqual(len(m.describer.lookup_tensor), 1)
        self.assertEqual(len(m.describer.lookup_storage), 1)
        del x
        # Entries from Tensor -> int get deallocated when the real tensor
        # disappears...
        self.assertEqual(len(m.describer.lookup_tensor), 0)
        self.assertEqual(len(m.describer.lookup_storage), 0)
        del y
        del z
        # ... but the int -> FakeTensor entries don't die until the fake
        # tensors themselves die (because the user may have held onto the
        # int key and are expecting to get a consistent fake tensor in
        # this case)
        self.assertEqual(len(m.tensor_memo), 0)
        self.assertEqual(len(m.storage_memo), 0)
        li = []
        r = []
        for i in range(4):
            li.append(torch.rand([i]))
            r.append(m(li[-1]))
        self.assertEqual(len(m.tensor_memo), 4)
        self.assertEqual(len(m.storage_memo), 4)
        self.assertEqual(len(m.describer.lookup_tensor), 4)
        self.assertEqual(len(m.describer.lookup_storage), 4)
        del li
        self.assertEqual(len(m.describer.lookup_tensor), 0)
        self.assertEqual(len(m.describer.lookup_storage), 0)
        del r
        self.assertEqual(len(m.tensor_memo), 0)
        self.assertEqual(len(m.storage_memo), 0)

    # The converter must be freed once deleted, even though a tensor it
    # produced (y) is still alive: y must not keep m alive.
    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
    def test_tensor_outlives_converter(self):
        m = MetaConverter()
        ref = weakref.ref(m)
        x = torch.randn([4, 4])
        y = m(x)
        del m
        self.assertIs(ref(), None)
345
aten = torch.ops.aten

# Non-OpOverload callables whose outputs still get significant-stride checks.
CHECK_STRIDES = {torch.Tensor.__getitem__}

# Ops whose outputs must match the real outputs on every stride, not only on
# the significant ones.
CHECK_ALL_STRIDES = {aten.unsqueeze.default}

# Ops whose output strides are not compared.
CHECK_STRIDES_SKIPS = {
    # Known stride mismatches, sorted by op name.
    aten._conj_physical.default,
    aten._fft_c2c.default,
    aten._fft_c2r.default,
    aten._fft_r2c.default,
    aten._linalg_svd.default,
    aten.binary_cross_entropy.default,
    aten.complex.default,
    aten.copysign.Tensor,
    aten.div.Tensor_mode,
    aten.floor_divide.default,
    aten.heaviside.default,
    aten.lerp.Scalar,
    aten.lerp.Tensor,
    aten.logaddexp.default,
    aten.logical_and.default,
    aten.logical_or.default,
    aten.logical_xor.default,
    aten.nll_loss2d_forward.default,
    aten.polar.default,
    aten.pow.Scalar,
    aten.prelu.default,
    aten.special_xlog1py.default,
    aten.xlogy.Tensor,

    # channels_last and channels_last_3d related failures
    aten.convolution.default,

    # The following ops fail when include_storage_offset = True.  They are
    # edge cases but should still be fixed; they are kept here for tracking.
    # aten._reshape_alias.default,  # repro with test_dispatch_symbolic_meta_outplace_all_strides_matmul_cuda_float32
    # aten.view.default,  # repro with test_dispatch_symbolic_meta_outplace_all_strides_unflatten_cuda_float32
}

# Ops whose conj bit is not compared.
CHECK_CONJ_SKIPS = {
    # The conj bit is not copied, see:
    # https://github.com/pytorch/pytorch/pull/101836
    aten.linalg_lu_solve.out,
}
395
# How strictly output strides are compared against the real implementation:
# NONE skips the check, SIGNIFICANT compares only significant strides, and
# ALL compares every stride.
CheckStrides = Enum("CheckStrides", [("NONE", 0), ("SIGNIFICANT", 1), ("ALL", 2)])
400
def should_check_strides(func):
    """Return the ``CheckStrides`` level for comparing outputs of ``func``.

    The explicit tables are consulted first, in priority order:
    ``CHECK_ALL_STRIDES`` (ALL), ``CHECK_STRIDES`` (SIGNIFICANT), then
    ``CHECK_STRIDES_SKIPS`` (NONE).  Any other callable that is not an
    ``OpOverload`` is not stride-checked.  Every remaining ``OpOverload``
    (prims, views, and the rest) gets SIGNIFICANT.
    """
    explicit_levels = (
        (CHECK_ALL_STRIDES, CheckStrides.ALL),
        (CHECK_STRIDES, CheckStrides.SIGNIFICANT),
        (CHECK_STRIDES_SKIPS, CheckStrides.NONE),
    )
    for table, level in explicit_levels:
        if func in table:
            return level

    if not isinstance(func, torch._ops.OpOverload):
        return CheckStrides.NONE

    # Prims are expected to model strides correctly.
    if func.namespace == "prims":
        return CheckStrides.SIGNIFICANT

    # A view: at least one return carries a non-empty alias set.
    returns_alias = any(
        ret.alias_info.before_set
        for ret in func._schema.returns
        if ret.alias_info
    )
    if returns_alias:
        return CheckStrides.SIGNIFICANT

    # TODO: check for TensorIterator
    return CheckStrides.SIGNIFICANT
419
def assert_ref_meta_equal(test_case, func, meta_rs, rs, msg_callable):
    """Check that the meta outputs ``meta_rs`` agree with the real outputs ``rs``.

    Both results are flattened with pytree and must have the same number of
    leaves (checked with ``test_case.assertEqual``).  Leaves of ``rs`` that
    are not tensors are ignored.  For every real tensor leaf, the matching
    meta leaf must be a tensor with the same:

    * dtype and shape;
    * strides: all of them or only the significant ones, as chosen by
      ``should_check_strides(func)``, and not compared for ``CheckStrides.NONE``;
    * storage offset and ``requires_grad``;
    * conj bit (not compared when ``func`` is in ``CHECK_CONJ_SKIPS``);
    * neg bit.

    Args:
        test_case: TestCase used for the leaf-count comparison.
        func: the op that produced both results; selects the stride/conj checks.
        meta_rs: result of running ``func`` on meta inputs.
        rs: result of running ``func`` on real inputs.
        msg_callable: maps a short mismatch description to the full message.

    Raises:
        RuntimeError: on the first mismatch, with message
            ``f"output {i}: {msg_callable(msg)}"``.
    """
    flat_meta_rs = pytree.tree_leaves(meta_rs)
    flat_rs = pytree.tree_leaves(rs)
    test_case.assertEqual(len(flat_meta_rs), len(flat_rs))

    # Both depend only on ``func``.  Compute them once instead of re-deriving
    # them for every output (the stride level was computed up to twice each).
    stride_mode = should_check_strides(func)
    check_conj = func not in CHECK_CONJ_SKIPS

    def test_assert(i, cond, msg):
        if not cond:
            raise RuntimeError(f"output {i}: {msg_callable(msg)}")

    for i, (meta_r, r) in enumerate(zip(flat_meta_rs, flat_rs)):
        if not isinstance(r, torch.Tensor):
            continue
        test_assert(i, isinstance(meta_r, torch.Tensor), f"but real {i}th result is Tensor")
        test_assert(i, meta_r.dtype == r.dtype, f"for element {i}, was {meta_r.dtype} but real dtype was {r.dtype}")
        test_assert(i, meta_r.shape == r.shape, f"for element {i}, was {meta_r.shape} but real shape was {r.shape}")
        # See https://github.com/pytorch/pytorch/issues/78050
        # The stride message is only built inside these branches, so .stride()
        # is never called when strides are not checked.
        if stride_mode == CheckStrides.ALL:
            same_strides, _ = torch._prims_common.check_all_strides(meta_r, r)
            test_assert(i, same_strides, f"for element {i}, was {meta_r.stride()} but real stride was {r.stride()}")
        elif stride_mode == CheckStrides.SIGNIFICANT:
            same_strides, _ = torch._prims_common.check_significant_strides(meta_r, r)
            test_assert(i, same_strides, f"for element {i}, was {meta_r.stride()} but real stride was {r.stride()}")
        test_assert(
            i,
            meta_r.storage_offset() == r.storage_offset(),
            f"for element {i}, was {meta_r.storage_offset()} but real storage_offset was {r.storage_offset()}")
        test_assert(i, meta_r.requires_grad == r.requires_grad,
                    f"for element {i}, was {meta_r.requires_grad} but real requires_grad was {r.requires_grad}")
        if check_conj:
            test_assert(i, meta_r.is_conj() == r.is_conj(),
                        f"for element {i}, was {meta_r.is_conj()} but real is_conj was {r.is_conj()}")
        test_assert(i, meta_r.is_neg() == r.is_neg(),
                    f"for element {i}, was {meta_r.is_neg()} but real is_neg was {r.is_neg()}")
449
450
451# This environment variable controls whether or not we print expected failure
452# lists at the end of a test suite run.  The intended usage looks like this:
453#
# 1. Run `PYTORCH_COLLECT_EXPECT=1 python test/test_meta.py` on a CUDA build
#    of PyTorch that has LAPACK/MAGMA installed.  You can filter with
#    `-k test_meta` or `-k test_dispatch_meta` to focus on just one of the lists.
457# 2. Given the printed skip/xfail list, add them to the corresponding lists;
458#    torch.* entries go in meta_function and aten.* entries go in meta_dispatch.
459#    If there are preexisting entries, you need to merge in the entries.
460#
461# This is somewhat manual but typically you shouldn't need to do this, unless
462# you've made a major change (e.g., added a new dtype to PyTorch) and need to
463# refresh the lists.  If you want to do it from scratch, just clear out the
464# preexisting lists before running.
465#
# WARNING: a Python dict literal silently keeps only the last value for a duplicated key, so check for duplicates when merging entries
COLLECT_EXPECT = os.environ.get('PYTORCH_COLLECT_EXPECT', '0') == '1'  # True only for the exact value "1"
468
# Outcomes recorded by run_meta_crossref, keyed by the function under test.
# Each value is the set of dtypes for which the meta run matched the real
# run (seen_succeeded) or failed (seen_failed).
seen_succeeded = dict()
seen_failed = dict()
# Function -> operator names captured from "Could not run '...'"
# NotImplementedError messages.
failed_reasons = defaultdict(set)


def print_seen():
    """Print expected-failure and skip tables built from the recorded outcomes.

    For each function in ``seen_failed``:

    * dtypes that only ever failed are listed under ``expected_failures``,
      annotated with any captured failure reasons;
    * dtypes that both failed and succeeded are listed under ``skips``.

    Output lines are sorted so the tables can be pasted into this file.
    """
    def fmt_dtypes(dtypes):
        return '{' + ', '.join(sorted(dtype_abbrs[d] for d in dtypes)) + '}'

    xfail_lines = []
    skip_lines = []
    for func, failed in seen_failed.items():
        # Not called ``ops``, to avoid shadowing the imported ``ops`` decorator.
        name = resolve_name(func)
        succeeded = seen_succeeded.get(func, set())
        only_failed = failed - succeeded
        mixed = failed & succeeded
        reasons = failed_reasons[func]
        suffix = "  # " + ", ".join(sorted(reasons)) if reasons else ""
        if only_failed:
            xfail_lines.append(f"    {name}: {fmt_dtypes(only_failed)},{suffix}")
        if mixed:
            skip_lines.append(f"    {name}: {fmt_dtypes(mixed)},")

    newline = '\n'
    print(
        "expected_failures = {\n"
        f"{newline.join(sorted(xfail_lines))}\n"
        "}\n"
        "\n"
        "skips = {\n"
        f"{newline.join(sorted(skip_lines))}\n"
        "}\n"
    )


if COLLECT_EXPECT:
    atexit.register(print_seen)
506
class TestExpect(Enum):
    """Expected outcome of the meta cross-reference run for one call.

    SUCCESS:  the meta run must succeed and agree with the real run.
    XFAILURE: the meta run is expected to fail.  A failure is tolerated, and
              an unexpected success is reported as an error (unless
              expectations are being collected).
    SKIP:     the meta run is not attempted at all.
    """
    SUCCESS = 1
    XFAILURE = 2
    SKIP = 3
509
510# unlike print produce strides
511def verbose_print(e):
512    class Lit:
513        def __init__(self, s):
514            self.s = s
515
516        def __repr__(self):
517            return self.s
518
519    def go(t):
520        if is_sparse_any(t):
521            return t
522        elif isinstance(t, torch.Tensor):
523            return Lit(f"{t} stride={t.stride()}")
524        else:
525            return t
526
527    return repr(tree_map(go, e))
528
def run_meta_crossref(
    test_case,
    test_expect,
    func,
    args,
    kwargs,
    *,
    dtype,
    device_type,
    run_symbolic_meta: bool
):
    """Run ``func`` on real inputs and, unless skipped, cross-check it on meta inputs.

    The real call always runs and its result is returned.  When
    ``test_expect`` is not ``TestExpect.SKIP``, ``args``/``kwargs`` are also
    converted with ``MetaConverter``.  If every tensor converted, ``func`` is
    rerun on the meta inputs (after a few per-op input fix-ups), and the two
    results are compared with ``assert_ref_meta_equal``.  Outcomes are recorded
    per dtype in ``seen_succeeded`` / ``seen_failed`` / ``failed_reasons``.

    Args:
        test_case: TestCase passed to ``assert_ref_meta_equal``.
        test_expect: expected outcome.  With XFAILURE, a meta failure or
            mismatch is swallowed, and a success raises (unless COLLECT_EXPECT).
        func: the callable under test.
        args, kwargs: real inputs.
        dtype: dtype under test; used as the key when recording outcomes.
        device_type: accepted for interface compatibility; not used here.
        run_symbolic_meta: if True, run the meta call under
            ``enable_python_dispatcher()``.

    Returns:
        The result of ``func(*args, **kwargs)`` on the real inputs.

    Raises:
        RuntimeError: if conversion to meta fails, if the meta call fails when
            not expected to, or on an unexpected success.
        AssertionError: if the real call itself raises.
        Any error raised by ``assert_ref_meta_equal`` on an unexpected mismatch
        is re-raised (unless COLLECT_EXPECT).
    """
    to_meta = MetaConverter()
    do_meta = test_expect is not TestExpect.SKIP
    if do_meta:
        try:
            meta_args = tree_map(to_meta, args)
            meta_kwargs = tree_map(to_meta, kwargs)
        except Exception as e:
            raise RuntimeError(
                f"failed to convert args to meta; "
                f"originally (*{args}, **{kwargs})") from e
    try:
        rs = func(*args, **kwargs)
    except Exception as e:
        raise AssertionError("Original OpInfo is broken") from e

    # TODO: also handle cases where func raises an exception

    # For now, only attempt if we managed to convert all tensor types
    # (if any of them failed, we're in a mixed device situation and
    # this isn't well supported)
    if do_meta and to_meta.successful():
        # Special cases
        if func is torch.tensor_split:
            # Use original indices_or_sections, this argument is data dependent
            meta_args = (meta_args[0], args[1]) + meta_args[2:]
        elif func is torch.Tensor.__getitem__:
            # Ensure boolean tensors use original
            # NOTE(review): legacy uint8 masks are not matched by torch.int8
            # here; torch.uint8 may have been intended.  Confirm before
            # changing, since it could alter which cases pass.
            assert len(args) == 2
            flat_args = pytree.tree_leaves(args[1])
            flat_meta_args, spec = tree_flatten(meta_args[1])
            flat_new_args = []
            for a, ma in zip(flat_args, flat_meta_args):
                flat_new_args.append(a if isinstance(a, torch.Tensor) and a.dtype in [torch.int8, torch.bool] else ma)
            meta_args = (meta_args[0], tree_unflatten(flat_new_args, spec))
        elif func in (torch.ops.aten.repeat_interleave.Tensor, torch.ops.aten.repeat_interleave.Tensor_out):
            # Without output_size, run on the real inputs instead of meta ones.
            if kwargs.get("output_size", None) is None:
                meta_args = args
            if func is torch.ops.aten.repeat_interleave.Tensor_out:
                meta_kwargs["out"] = kwargs["out"]
        elif func in (torch.ops.aten.index.Tensor, torch.ops.aten.index.Tensor_out):
            # Don't convert boolean tensors to meta as they will have nonzero
            # called on them
            indices = []
            for meta_index, real_index in zip(meta_args[1], args[1]):
                if meta_index is not None and meta_index.dtype in [torch.int8, torch.bool]:
                    indices.append(real_index)
                else:
                    indices.append(meta_index)
            meta_args = (meta_args[0], indices)
        elif func is torch.nn.functional.ctc_loss and isinstance(args[2], list) and isinstance(args[3], list):
            # torch.ops.aten._ctc_loss.IntList has a meta kernel but
            # torch.ops.aten._ctc_loss.Tensor does not
            test_expect = TestExpect.SUCCESS

        if kwargs.get("device", None) is not None:
            meta_kwargs["device"] = "meta"

        try:
            # Suppress warnings, this doesn't matter for test_meta.py
            # but it does matter if you want to use this decorator
            # for cross-ref testing, as some tests may be looking at
            # errors
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                if run_symbolic_meta:
                    # Run the decomps and meta kernels registered
                    # to the python dispatcher instead of the regular dispatcher.
                    # This should be the same set of kernels
                    # that fake tensor runs in dynamic shapes mode.
                    with enable_python_dispatcher():
                        meta_rs = func(*meta_args, **meta_kwargs)
                else:
                    meta_rs = func(*meta_args, **meta_kwargs)
        except Exception as e:
            if test_expect is TestExpect.XFAILURE:
                return rs
            seen_failed.setdefault(func, set()).add(dtype)
            if isinstance(e, NotImplementedError):
                # Search str(e), not e.args[0]: an exception raised with no
                # args (IndexError) or a non-str first arg (TypeError) would
                # otherwise blow up here and mask the real failure.
                match = RE_NOT_IMPLEMENTED_MSG.search(str(e))
                if match:
                    failed_reasons[func].add(match.group(1))
            if COLLECT_EXPECT:
                return rs
            raise RuntimeError(f"""\
failed to run: {resolve_name(func)}(
*{verbose_print(meta_args)},
**{verbose_print(meta_kwargs)}
)""") from e
        else:
            try:
                delim = ',\n  '
                assert_ref_meta_equal(test_case, func, meta_rs, rs, lambda msg: f"""\
meta disagrees with real impl:
{resolve_name(func)}(
  {delim.join(map(verbose_print, meta_args))},
  {delim.join(k + ": " + verbose_print(v) for k, v in meta_kwargs.items())}
) = (
  {verbose_print(meta_rs)}
)
{msg}
""")
            except Exception:
                if test_expect is TestExpect.XFAILURE:
                    return rs
                seen_failed.setdefault(func, set()).add(dtype)
                if COLLECT_EXPECT:
                    return rs
                raise
            else:
                seen_succeeded.setdefault(func, set()).add(dtype)
                if test_expect is TestExpect.XFAILURE and not COLLECT_EXPECT:
                    raise RuntimeError(f"unexpected success {resolve_name(func)} {meta_args} {meta_kwargs}")

    return rs
654
655
656
# Captures the quoted operator name from dispatcher errors that start with
# "Could not run '<op>' with arguments ...".
RE_NOT_IMPLEMENTED_MSG = re.compile(
    r"Could not run '([^']+)' with arguments "
)
658
# Dtype sets in these tables are written in one fixed order (floating point,
# complex, signed int, unsigned int, bool) to make them easier to scan.  The
# order inside a set has no effect.

# Torch API functions whose meta run is expected to fail for the listed
# dtypes, independent of device.
meta_function_expected_failures = {
    torch.Tensor.to_sparse : {f64, f32, f16, bf16, c128, c64, i64, i32, i16, i8, u8, b8},
    torch.allclose : {f64, f32, f16, bf16, c128, c64},
    torch.argwhere : {f64, f32, f16, bf16, c128, c64, i64, i32, i16, i8, u8, b8},
    torch.combinations : {f64, f32, f16, bf16, c128, c64, i64, i32, i16, i8, u8, b8},
    torch.corrcoef : {f64, f32, f16, bf16, c128, c64, i64, i32, i16, i8, u8},
    torch.cov : {f64, f32, f16, bf16, c128, c64, i64, i32, i16, i8, u8},
    torch.functional.istft : {f64, f32, c128, c64},
    torch.geqrf : {f64, f32, c128, c64},
    torch.masked_select : {f64, f32, f16, bf16, c128, c64, i64, i32, i16, i8, u8, b8},
    torch.nonzero : {f64, f32, f16, bf16, c128, c64, c32, i64, i32, i16, i8, u8, b8},
    torch.Tensor.nonzero : {f64, f32, f16, bf16, c128, c64, c32, i64, i32, i16, i8, u8, b8},
    torch.Tensor.item : {f64, f32, f16, bf16, c128, c64, c32, i64, i32, i16, i8, u8, b8},
    torch.bincount : {i64, i32, i16, i8, u8},
    torch.functional.unique : {f64, f32, f16, bf16, i64, i32, i16, i8, u64, u32, u16, u8, b8},
    torch.functional.unique_consecutive : {f64, f32, f16, bf16, i64, i32, i16, i8, u64, u32, u16, u8, b8},
    torch.histogram : {f64, f32},
    torch.histogramdd : {f64, f32},
    torch.nn.functional.ctc_loss : {f64, f32},
    torch.nn.functional.gaussian_nll_loss : {f64, f32, f16, bf16},
    torch.linalg.lstsq : {f64, f32, c128, c64},
}

# Argument-dependent expected failures.  Each predicate receives
# (dtype, *args, **kwargs) and returns True when a failure is expected.
# NOTE(review): only a keyword `repeats` is inspected, so an int passed
# positionally counts as "not an int" -- confirm this matches the samples.
meta_function_expected_failures_conditional = {
    torch.repeat_interleave : (lambda dtype, *args, **kwargs: not isinstance(kwargs.get("repeats", None), int)),
}

"""
# This is some sample code for how we could dump these dicts into YAML
# file for easier reading/writing
import yaml
print(yaml.dump(
  {resolve_name(k): [dtype_abbrs[d] for d in v]
   for k, v in meta_function_expected_failures.items()}, default_flow_style=None))
import sys
sys.exit()
"""
696
# Torch API functions whose meta cross-check is not run at all for the listed
# dtypes, independent of device.  Dtype sets use the fixed order floating
# point, complex, signed int, unsigned int, bool.
meta_function_skips = {
    torch.Tensor.__rmatmul__ : {f64, f32, f16, bf16, c128, c64},
    torch.Tensor.matmul : {f64, f32, c128, c64},
    torch.functional.atleast_2d : {f64, f32, f16, bf16, c128, c64, c32, i64, i32, i16, i8, u8, b8},
    torch.functional.atleast_3d : {f64, f32, f16, bf16, c128, c64, c32, i64, i32, i16, i8, u8, b8},
    torch.functional.cartesian_prod : {f64, f32, f16, bf16, c128, c64, i64, i32, i16, i8, u8, b8},
    torch.functional.einsum : {f64, f32, f16, bf16, c128, c64},
    torch.inner : {f64, f32, f16, bf16, c128, c64, i64, i32, i16, i8, u8},
    torch.linalg.matrix_norm : {f64, f32, c128, c64},
    torch.linalg.matrix_rank : {c128, c64},
    torch.linalg.svd : {c128, c64},
    torch.matmul : {f64, f32, f16, bf16, c128, c64},
    torch.nanquantile : {f64, f32},
    torch.narrow : {f64, f32, f16, bf16, c128, c64, c32, i64, i32, i16, i8, u8, b8},
    torch.nn.functional.batch_norm : {f64, f32},
    torch.nn.functional.binary_cross_entropy : {f64, f32, f16, bf16},
    torch.nn.functional.dropout3d : {f64, f32, f16, bf16},
    torch.nn.functional.local_response_norm : {f64, f32, f16, bf16},
    torch.svd : {c128, c64},
    torch.take_along_dim : {f64, f32, f16, bf16, c128, c64, i64, i32, i16, i8, u8, b8},
    torch.vstack : {f64, f32, f16, bf16, c128, c64, c32, i64, i32, i16, i8, u8, b8},
    torch.diff : {b8},
    torch.equal : {f64, f32, f16, bf16, c128, c64, c32, i64, i32, i16, i8, u8, b8},
    torch.nanmean : {f64, f32, f16, bf16, c128, c64, c32},
    torch.nn.functional.cross_entropy : {f64, f32, bf16},
    torch.nn.functional.nll_loss : {f64, f32, bf16},
    torch.linalg.cond : {f64, f32, c128, c64},
    torch.linalg.vecdot : {f64, f32, f16, bf16},
    torch.empty : {f64, f32, f16, bf16, c128, c64, c32, i64, i32, i16, i8, u8, b8},
    torch.Tensor.addbmm_: {f64, f32, bf16, c128, c64, i64, i32, i16, i8, u8},
    torch.nn.functional.one_hot : {i64},
}
729
730
# Per-device-type tables: device type -> {function: dtypes}.  Unknown device
# types read as empty dicts because these are defaultdicts.
meta_function_device_expected_failures = defaultdict(dict)
meta_function_device_expected_failures_only_outplace = defaultdict(dict)
meta_function_device_skips = defaultdict(dict)

meta_function_device_expected_failures.update({
    'cpu': {
        # TODO: The decomps for these batch norm ops return different dtypes depending
        # on the device. We should make this work better with meta tensors.
        torch.native_batch_norm: {f16, bf16},
        torch._native_batch_norm_legit: {f16, bf16},
        torch.ops.aten._batch_norm_with_update: {f16, bf16},
        torch.native_layer_norm: {f16, bf16},
    },
    'cuda': {
        torch.corrcoef: {f16, bf16},  # aten::_local_scalar_dense
        torch.cov: {f16},  # aten::_local_scalar_dense
        torch.functional.unique: {f16},  # aten::_unique2, aten::unique_dim
        torch.functional.unique_consecutive: {f16},  # aten::unique_consecutive
        torch.geqrf: {f64, f32},  # aten::geqrf
    },
})

meta_function_device_skips.update({
    'cpu': {
        # TODO: The decomps for these batch norm ops return different dtypes depending
        # on the device. We should make this work better with meta tensors.
        torch.native_batch_norm: {f64, f32},
        torch._native_batch_norm_legit: {f64, f32},
        torch.ops.aten._batch_norm_with_update: {f64, f32},
    },
    'cuda': {
        torch.inner: {f16},
        torch.linalg.matrix_rank: {f64, f32},
        torch.linalg.svd: {f64, f32},
        torch.nn.functional.cross_entropy: {f16},
        torch.nn.functional.interpolate: {f16},
        torch.nn.functional.nll_loss: {f16},
        torch.svd: {f64, f32},
    },
})
769
# MetaCrossRefFunctionMode is a __torch_function__ mode.  While it is active,
# every Torch API call is intercepted: the operator runs as normal, is rerun
# with meta inputs, and everything about the two outputs is checked for
# agreement.  Most of the logic deals with faithfully replicating the original
# tensors as meta tensors, which is nontrivial because many subsystems may
# potentially be exercised.
#
# For this file alone the mode is a little overkill (__torch_function__ could
# simply have been inlined into the OpInfo call, and OpInfo inputs are
# generally regular), but it is useful for more comprehensive testing, e.g. as
# in https://github.com/pytorch/pytorch/pull/75994.  Its big benefit is that
# it is a LOT more efficient than a torch dispatch mode, at the cost of less
# coverage.
class MetaCrossRefFunctionMode(torch.overrides.TorchFunctionMode):
    test_case: TestCase
    device_type: str
    dtype: torch.dtype

    def __init__(self, test_case, *, device, dtype, inplace):
        # Only the device *type* is kept (e.g. "cpu", "cuda"); it selects the
        # per-device expectation tables.
        self.test_case, self.dtype, self.inplace = test_case, dtype, inplace
        self.device_type = torch.device(device).type

    def __torch_function__(self, func, types, args=(), kwargs=None):
        """Run ``func`` and cross-check it against a rerun on meta inputs.

        The call is passed straight through, with no cross-check, while JIT
        tracing, for ScriptMethods, and when the Python dispatch key is in the
        local exclude set.  Otherwise the expectation for (func, dtype, device
        type) is looked up and the call is delegated to ``run_meta_crossref``.
        Skips are checked first.  Expected failures come next: static,
        per-device, argument-conditional, then out-of-place-only when
        ``self.inplace`` is false.
        """
        kwargs = kwargs or {}

        # meta converter doesn't work correctly when no_dispatch() is on, so
        # skip running the crossref test in that case as well.
        passthrough = (
            torch.jit.is_tracing()
            or isinstance(func, torch.ScriptMethod)
            or torch._C._dispatch_tls_local_exclude_set().has(torch._C.DispatchKey.Python)
        )
        if passthrough:
            return func(*args, **kwargs)

        dtype = self.dtype
        device = self.device_type
        is_skip = (
            dtype in meta_function_skips.get(func, set())
            or dtype in meta_function_device_skips[device].get(func, set())
        )
        # Short-circuit evaluation keeps the original lookup order, so the
        # conditional predicate only runs when the static tables did not match.
        is_xfail = (
            dtype in meta_function_expected_failures.get(func, set())
            or dtype in meta_function_device_expected_failures[device].get(func, set())
            or meta_function_expected_failures_conditional.get(
                func, lambda *_, **__: False)(dtype, *args, **kwargs)
            or (not self.inplace
                and dtype in meta_function_device_expected_failures_only_outplace[device].get(func, set()))
        ) if not is_skip else False

        if is_skip:
            test_expect = TestExpect.SKIP
        elif is_xfail:
            test_expect = TestExpect.XFAILURE
        else:
            test_expect = TestExpect.SUCCESS

        return run_meta_crossref(
            self.test_case,
            test_expect,
            func,
            args,
            kwargs,
            dtype=dtype,
            device_type=device,
            run_symbolic_meta=False,
        )
825
# Ops (keyed by OpOverload) expected to fail for the listed dtypes under
# MetaCrossRefDispatchMode on every device; _get_expected_test_result maps a
# hit here to TestExpect.XFAILURE.
# NOTE(review): many entries (nonzero, masked_select, unique*, histogram*,
# bincount, _local_scalar_dense, equal, allclose) look like ops whose output
# shape or value depends on tensor data — confirm per op before removing.
meta_dispatch_expected_failures = {
    aten.allclose.default: {f16, bf16, f32, f64, c64, c128},  # NotImplementedError: 'aten::_local_scalar_dense'
    aten.geqrf.default : {c64, c128, f64, f32},
    aten.linalg_lstsq.default : {c64, c128, f64, f32},
    aten.masked_select.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
    aten.masked_select.out : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
    aten.nonzero.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, c32, b8, i16, u8},
    aten.nonzero.out : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, c32, b8, i16, u8},
    aten._to_sparse.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
    aten._to_sparse.sparse_dim : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
    aten._ctc_loss.Tensor : {f32, f64},  # Shape of second output depends on data.
    aten._histogramdd_bin_edges.default : {f32, f64},
    aten._histogramdd_from_bin_cts.default : {f32, f64},
    aten._histogramdd_from_bin_tensors.default : {f32, f64},
    aten._local_scalar_dense.default : {c32, c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
    aten._unique2.default : {i8, f64, i64, f16, bf16, f32, i32, b8, i16, u8, u16, u32, u64},
    aten.bincount.default : {i64, i8, i32, i16, u8},
    aten.equal.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
    aten.histogram.bin_ct : {f32, f64},
    aten.histogram.bins_tensor : {f32, f64},
    aten.unique_consecutive.default : {i8, f64, i64, f16, bf16, f32, i32, b8, i16, u8, u16, u32, u64},
    aten.unique_dim.default : {i8, f64, i64, f16, bf16, f32, i32, b8, i16, u8, u16, u32, u64},
    aten.upsample_nearest3d.vec : {bf16, f32, f64, u8},

}
852
# Ops whose meta cross-check sometimes passes and sometimes fails, so they are
# marked TestExpect.SKIP (for the listed dtypes, on every device) by
# MetaCrossRefDispatchMode._get_expected_test_result.
meta_dispatch_skips = {
    aten.index.Tensor: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32, c32, c64, c128},  # at::nonzero doesn't have a Meta function
    aten._to_copy.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32, c32, c64, c128},
    aten.empty.memory_format: {b8, bf16, c128, c64, c32, f16, f32, f64, i16, i32, i64, i8, u8},
    aten.addbmm_.default: {bf16, c128, c64, f32, f64, i16, i32, i64, i8, u8},
}
860
# For CompositeImplicitAutograd functions that fail before hitting the Mode.
# _run_dispatch_meta_test skips these before running any sample.
meta_dispatch_early_skips = {
    torch.Tensor.float_power_,
    # Errors out in one of the tests, while ProxyTensor passes...
    torch.Tensor.cumprod_,
    torch.Tensor.cumsum_,
}
868
# In-place variants that test_meta_inplace skips outright.
meta_inplace_skips = {
    # Errors out in one of the tests, while ProxyTensor passes...
    torch.Tensor.cumprod_,
    torch.Tensor.cumsum_,
}
874
# Device-type-keyed tables consulted by MetaCrossRefDispatchMode; a device
# type with no entry yields an empty dict.
meta_dispatch_device_expected_failures, meta_dispatch_device_skips = (
    defaultdict(dict),
    defaultdict(dict),
)
877
# Expected failures (TestExpect.XFAILURE) under MetaCrossRefDispatchMode that
# apply only when the device type is 'cpu'.
meta_dispatch_device_expected_failures['cpu'] = {
    # TODO: The decomps for these batch norm ops return different dtypes depending
    # on the device. We should make this work better with meta tensors.
    aten.native_batch_norm.default: {bf16, f16},
    aten._native_batch_norm_legit.default: {bf16, f16},
    aten._native_batch_norm_legit.no_stats: {bf16, f16},
    aten._batch_norm_with_update.default: {bf16, f16},

    aten.native_layer_norm.default: {bf16, f16},
}
888
# Expected failures (TestExpect.XFAILURE) under MetaCrossRefDispatchMode that
# apply only when the device type is 'cuda'. The trailing comments name the
# aten op involved.
meta_dispatch_device_expected_failures['cuda'] = {
    aten._unique2.default: {f16},  # aten::_unique2
    aten._use_cudnn_ctc_loss.default: {f32, f64},  # aten::_use_cudnn_ctc_loss
    aten._use_cudnn_ctc_loss.Tensor: {f32, f64},  # aten::_use_cudnn_ctc_loss.Tensor
    aten.cudnn_grid_sampler.default: {f16, f32, f64},  # aten::cudnn_grid_sampler
    aten.geqrf.default: {f32, f64},  # aten::geqrf
    aten.linalg_eigvalsh.out: {f32, f64},  # aten::linalg_eigvalsh.out
    aten.log_sigmoid_forward.default: {bf16, f16, f64, f32},
    aten.log_sigmoid_forward.output : {bf16, f16, f64, f32},  # aten::log_sigmoid_forward.output
    aten.unique_consecutive.default: {f16},  # aten::unique_consecutive
    aten.unique_dim.default: {f16},  # aten::unique_dim
    aten.upsample_nearest3d.vec: {f16},  # aten::upsample_nearest3d.vec
}
902
# Skips (TestExpect.SKIP) under MetaCrossRefDispatchMode that apply only when
# the device type is 'cpu'.
meta_dispatch_device_skips['cpu'] = {
    aten._embedding_bag_forward_only.default: {bf16, f16, f32, f64},

    # TODO: The decomps for these batch norm ops return different dtypes depending
    # on the device. We should make this work better with meta tensors.
    aten.native_batch_norm.default: {f32, f64},
    aten._native_batch_norm_legit.default: {f32, f64},
    aten._native_batch_norm_legit.no_stats: {f32, f64},
    aten._batch_norm_with_update.default: {f32, f64},

    # If the computation dtype is different from the input
    # dtype this will fail. CPU execution may also have a
    # different output from other devices.
    aten.native_batch_norm.out: {bf16, f16, f32, f64}
}
918
# Skips (TestExpect.SKIP) under MetaCrossRefDispatchMode that apply only when
# the device type is 'cuda'. Each op appears once (a duplicated
# `aten.softmax.int` key, which Python silently collapses, was removed).
meta_dispatch_device_skips['cuda'] = {
    aten._conj.default: {c32, f16},  # file issue
    # NOTE(review): previously annotated "aten::linalg_eigvalsh.out", which
    # looks copy-pasted from the expected-failures table; confirm the cause.
    aten._linalg_svd.default: {c64, c128},
    aten.cudnn_batch_norm.default: {f32, f64},
    aten.log_softmax.int : {c32, c64},
    aten.softmax.int : {c32, c64},

    # ROCm stuff; technically this should be expected failure but it's
    # not worth it; these should get unified anyway
    aten.miopen_batch_norm.default: {f32},
}
931
def get_strided_args(args):
    """Yield every combination of memory-layout variants of ``args``.

    Each contiguous, strided-layout tensor in ``args`` is expanded into several
    tensors holding the same values but with different strides (see
    ``get_strided_variants``). Everything else -- non-tensors, non-contiguous
    tensors, and tensors with a non-strided layout (sparse COO/CSR/CSC/BSR/BSC,
    mkldnn, ...) -- is kept as its own single variant. The cartesian product of
    the per-argument variant lists is yielded as tuples in argument order.
    """

    def get_strided_variants(t, include_storage_offset=False):
        """Return ``t`` followed by value-equal copies that use other strides.

        Variants (each added only when applicable): the original tensor; a
        copy with fully reversed strides (ndim > 1); a non-dense view whose
        last-dim stride is 2 (ndim > 0); a channels_last copy (4-D); a
        channels_last_3d copy (5-D); and, if ``include_storage_offset``, a copy
        stored at storage offset 1.
        """
        variants = []

        # contiguous
        variants.append(t)

        # transposed: allocate with the reversed shape, then permute back so
        # the result has t's shape but reversed strides.
        if t.ndim > 1:
            perm = list(reversed(range(t.ndim)))
            transposed = torch.empty(
                t.shape[::-1], device=t.device, dtype=t.dtype, requires_grad=t.requires_grad
            ).permute(perm).copy_(t)
            variants.append(transposed)

        # nondense: duplicate each element along the last dim and keep every
        # other one -> same values, doubled last-dim stride.
        if t.ndim > 0:
            nondense = torch.repeat_interleave(t, 2, dim=-1)[..., ::2]
            variants.append(nondense)

        # channel_last
        if t.ndim == 4:
            variants.append(t.contiguous(memory_format=torch.channels_last))

        # channel_last_3d
        if t.ndim == 5:
            variants.append(t.contiguous(memory_format=torch.channels_last_3d))

        # storage_offset
        if include_storage_offset:
            buffer = torch.empty(t.numel() + 1, device=t.device, dtype=t.dtype, requires_grad=t.requires_grad)
            buffer = buffer.as_strided(t.shape, t.stride(), storage_offset=1)
            buffer.copy_(t)
            variants.append(buffer)

        return variants

    strided_args = []
    for arg in args:
        # Only strided-layout tensors have strides to vary; is_contiguous() is
        # not supported on sparse/opaque layouts, so every non-strided layout
        # (not just CSR) is passed through unchanged.
        if isinstance(arg, torch.Tensor) and arg.layout == torch.strided and arg.is_contiguous():
            strided_arg_variants = get_strided_variants(arg)
        else:
            strided_arg_variants = [arg]
        strided_args.append(strided_arg_variants)

    yield from itertools.product(*strided_args)
979
class MetaCrossRefDispatchMode(torch.utils._python_dispatch.TorchDispatchMode):
    """TorchDispatchMode that cross-checks every aten op call against meta.

    Each dispatched call is run through ``run_meta_crossref`` with the
    expectation looked up in the ``meta_dispatch_*`` tables. For out-of-place
    runs of ops whose OpInfo does not support ``out=``, it additionally tries
    to resolve an aten ``out`` overload of the same op and cross-checks that
    too, passing the first call's result as the out argument(s).
    """
    test_case: TestCase
    # NOTE(review): instances store `device_type` (a str), not `device`.
    device: torch.device
    dtype: torch.dtype
    # Class-level cache shared by all instances: overload packets already found
    # to have no overload with an `out` argument.
    aten_olp_no_out_overload: set = set()

    def __init__(self, test_case, *, device, dtype, symbolic_meta: bool, inplace: bool, supports_out: bool):
        self.test_case = test_case
        # save TLS
        # (the test case's precision/rel_tol at construction time; re-applied
        # to the test case on every __torch_dispatch__ call)
        self.precision = test_case.precision
        self.rel_tol = test_case.rel_tol
        self.device_type = torch.device(device).type
        self.dtype = dtype
        self.symbolic_meta = symbolic_meta
        self.inplace = inplace
        self.supports_out = supports_out

    @staticmethod
    def try_resolve_aten_out_overload(ol, args, kwargs, num_outputs):
        """Find an overload of ``ol``'s packet that takes ``out`` arguments.

        A candidate matches when it has more schema arguments than there are
        positional ``args``, its leading argument types equal ``ol``'s for
        every positional arg, its last ``num_outputs`` arguments are all
        ``out`` arguments, and every argument in between can be filled from
        ``kwargs``, a schema default, or (for optional types) ``False``/``None``.

        Returns:
            ``(overload, out_arg_names, new_kwargs)`` for the first match, else
            ``(None, None, None)``. ``new_kwargs`` holds only the non-out
            keyword arguments the candidate accepts; other kwargs are dropped.
        """
        ol_args = ol._schema.arguments
        olp: OpOverloadPacket = ol._overloadpacket

        # Packets already known to have no out overload are rejected cheaply.
        if olp in MetaCrossRefDispatchMode.aten_olp_no_out_overload:
            return (None, None, None)

        # Collect every overload whose schema has at least one out argument.
        candidate_ols = []
        for candidate_ol_name in olp.overloads():
            candidate_ol = getattr(olp, candidate_ol_name)
            if any(arg.is_out for arg in candidate_ol._schema.arguments):
                candidate_ols.append(candidate_ol)

        if not candidate_ols:
            MetaCrossRefDispatchMode.aten_olp_no_out_overload.add(olp)
            return (None, None, None)

        # Now match based on args, kwargs and number of required outputs
        candidate_ol: OpOverload = None
        for candidate_ol in candidate_ols:
            candidate_ol_args = candidate_ol._schema.arguments

            # The candidate needs room for the out argument(s) after the
            # positional args.
            if (len(args) >= len(candidate_ol_args)):
                continue

            # Positional arguments must have the same type
            if not all(
                ol_args[pos_arg_ind].type == candidate_ol_args[pos_arg_ind].type
                for pos_arg_ind in range(len(args))
            ):
                continue

            # Number of outputs must match
            # (out arguments are assumed to be the last num_outputs schema args)
            candidate_out_names = [out_arg.name for out_arg in candidate_ol_args[-num_outputs:] if out_arg.is_out]
            if len(candidate_out_names) != num_outputs:
                continue

            # Now try and match kwargs. Just need to ensure that the
            # remaining kwargs allow an out overload to be called. For example
            # we can throw away parameters like `dtype` that may be passed to the
            # functional version of the op since the `dtype` will already be present
            # in the `out` argument
            new_kwargs = {}
            kwargs_match = True
            for arg in candidate_ol_args[len(args):-num_outputs]:
                if arg.name not in kwargs:
                    if arg.has_default_value():
                        new_kwargs[arg.name] = arg.default_value
                    elif isinstance(arg.type, torch.OptionalType):
                        # Optional without a default: fill bool? with False and
                        # any other optional with None.
                        if isinstance(arg.type.getElementType(), torch.BoolType):
                            new_kwargs[arg.name] = False
                        else:
                            new_kwargs[arg.name] = None
                    else:
                        # Required argument that the caller did not provide.
                        kwargs_match = False
                        break
                else:
                    new_kwargs[arg.name] = kwargs[arg.name]

            if kwargs_match:
                return candidate_ol, candidate_out_names, new_kwargs

        return None, None, None

    def _get_expected_test_result(self, func: OpOverload):
        """Return the TestExpect for ``func`` at this mode's dtype/device type.

        Skips take priority over expected failures, and device-agnostic tables
        are checked before the per-device ones.
        """
        if self.dtype in meta_dispatch_skips.get(func, set()):
            test_expect = TestExpect.SKIP
        elif self.dtype in meta_dispatch_device_skips[self.device_type].get(func, set()):
            test_expect = TestExpect.SKIP
        elif self.dtype in meta_dispatch_expected_failures.get(func, set()):
            test_expect = TestExpect.XFAILURE
        elif self.dtype in meta_dispatch_device_expected_failures[self.device_type].get(func, set()):
            test_expect = TestExpect.XFAILURE
        else:
            test_expect = TestExpect.SUCCESS
        return test_expect

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Cross-check ``func`` against meta and return the real result."""
        kwargs = kwargs or {}
        # Restore the precision/rel_tol captured in __init__ on the test case.
        self.test_case.precision = self.precision
        self.test_case.rel_tol = self.rel_tol

        test_expect = self._get_expected_test_result(func)

        expected = run_meta_crossref(
            self.test_case,
            test_expect,
            func,
            args,
            kwargs,
            dtype=self.dtype,
            device_type=self.device_type,
            run_symbolic_meta=self.symbolic_meta,
        )

        # This is to test torch ops that do not have an out parameter but have
        # aten op overloads that have out parameters. Additionally, Python decompositions
        # may register OpOverloadPacket's so decompositions need to be tested
        # to ensure all OpOverloads still function for the Meta key (e.g. if a python decomposition
        # is registered for an aten op aten.foo with overloads [default, out], the python
        # function needs to support receiving `out` arguments)
        if (
            not self.inplace and
            not self.supports_out and
            test_expect == TestExpect.SUCCESS and
            (torch.is_tensor(expected) or isinstance(expected, Iterable))
        ):

            # check to see if there is a potential out overload
            # (a tensor result counts as one output; otherwise one per element)
            num_outputs = 1 if torch.is_tensor(expected) else len(expected)
            func_out_overload, out_param_names, kwargs = self.try_resolve_aten_out_overload(func, args, kwargs, num_outputs)

            if func_out_overload:

                if num_outputs == 1:
                    # `expected` itself (a tensor, or a length-1 iterable) is
                    # passed as the single out argument.
                    kwargs[out_param_names[0]] = expected
                else:
                    for ind, out_param_name in enumerate(out_param_names):
                        kwargs[out_param_name] = expected[ind]

                test_expect = self._get_expected_test_result(func_out_overload)

                run_meta_crossref(
                    self.test_case,
                    test_expect,
                    func_out_overload,
                    args,
                    kwargs,
                    dtype=self.dtype,
                    device_type=self.device_type,
                    run_symbolic_meta=self.symbolic_meta,
                )

        return expected
1133
1134# NB: we're running these tests only on CUDA because there are some
1135# inconsistencies between CUDA and CPU, and running on CUDA makes it easier
1136# to ignore the CPU case when inconsistencies arise.  Ideally we deal
1137# with the inconsistencies but this takes time.
1138@unMarkDynamoStrictTest
1139class TestMeta(TestCase):
1140    # Copies inputs to inplace operations to avoid inplace modifications
1141    #   to leaves requiring gradient
1142    def _get_safe_inplace(self, inplace_variant):
1143        @wraps(inplace_variant)
1144        def _fn(t, *args, **kwargs):
1145            if isinstance(t, list):
1146                return inplace_variant([x.clone() for x in t], *args, **kwargs)
1147            else:
1148                return inplace_variant(t.clone(), *args, **kwargs)
1149
1150        return _fn
1151
1152    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1153    @skipIfCrossRef
1154    @suppress_warnings
1155    @ops(itertools.chain(op_db, foreach_op_db))
1156    def test_meta_outplace(self, device, dtype, op):
1157        if "_scaled_mm" in op.name:
1158            raise unittest.SkipTest("_scaled_mm dose not support meta device")
1159        skip_op_names = (
1160            "fft.ihfft",
1161            "fft.ihfft2",
1162            "linalg.lu_solve",
1163        )
1164        if TEST_WITH_TORCHDYNAMO and op.name in skip_op_names:
1165            raise unittest.SkipTest("flaky")
1166        # run the OpInfo sample inputs, cross-referencing them with the
1167        # meta implementation and check the results are the same.  All
1168        # the heavy lifting happens in MetaCrossRefFunctionMode
1169        func = op.get_op()
1170        samples = op.sample_inputs(device, dtype, requires_grad=False)
1171        for sample_input in samples:
1172            args = [sample_input.input] + list(sample_input.args)
1173            kwargs = sample_input.kwargs
1174            with MetaCrossRefFunctionMode(self, dtype=dtype, device=device, inplace=False):
1175                expected = func(*args, **kwargs)
1176                if isinstance(expected, torch.Tensor) and op.supports_out:
1177                    func(*args, **kwargs, out=expected)
1178
1179            # Special test for functions taking "device" kwarg
1180            # The crossref tests that replacing the device with "meta" works
1181            # This part makes sure that *_like functions work well with a "meta"
1182            # Tensor and their original device argument.
1183            if "device" in kwargs and "_like" in op.name:
1184                with torch.random.fork_rng():
1185                    torch.manual_seed(123)
1186                    ref = func(*args, **kwargs)
1187
1188                # *_like functions take a Tensor as first argument
1189                assert isinstance(args[0], torch.Tensor)
1190                with torch.random.fork_rng():
1191                    torch.manual_seed(123)
1192                    args[0] = args[0].to(device="meta")
1193                    meta = func(*args, **kwargs)
1194
1195                # empty_like is not deterministic
1196                if op.name != "empty_like":
1197                    self.assertEqual(ref, meta)
1198
1199    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1200    @skipIfCrossRef
1201    @suppress_warnings
1202    @ops(itertools.chain(op_db, foreach_op_db))
1203    def test_meta_inplace(self, device, dtype, op):
1204        func = op.get_inplace()
1205        if not func:
1206            self.skipTest("No inplace variable for this op")
1207        if op.promotes_int_to_float and not dtype.is_floating_point:
1208            self.skipTest("Op promotes to float, which is impossible for inplace with non-float input")
1209        if func in meta_inplace_skips:
1210            self.skipTest("Skipped")
1211        func = self._get_safe_inplace(func)
1212        samples = op.sample_inputs(device, dtype, requires_grad=False)
1213        for sample_input in samples:
1214            if sample_input.broadcasts_input:
1215                continue
1216            args = [sample_input.input] + list(sample_input.args)
1217            kwargs = sample_input.kwargs
1218            with MetaCrossRefFunctionMode(self, dtype=dtype, device=device, inplace=True):
1219                expected = func(*args, **kwargs)
1220
1221    def _run_dispatch_meta_test(self, device, dtype, op, symbolic_meta, inplace, all_stride_variants=False):
1222        if "_scaled_mm" in op.name:
1223            raise unittest.SkipTest("_scaled_mm dose not support meta device")
1224        if inplace:
1225            func = op.get_inplace()
1226            if not func:
1227                self.skipTest("No inplace variable for this op")
1228            if op.promotes_int_to_float and not dtype.is_floating_point:
1229                self.skipTest("Op promotes to float, which is impossible for inplace with non-float input")
1230        else:
1231            func = op.get_op()
1232
1233        if func in meta_dispatch_early_skips:
1234            self.skipTest("Function is in dispatch early skips")
1235
1236        if inplace:
1237            func = self._get_safe_inplace(func)
1238
1239        samples = op.sample_inputs(device, dtype, requires_grad=False)
1240        for sample_input in samples:
1241            if inplace and sample_input.broadcasts_input:
1242                continue
1243
1244            sample_args = [sample_input.input] + list(sample_input.args)
1245            kwargs = sample_input.kwargs
1246
1247            if all_stride_variants and sum(isinstance(arg, torch.Tensor) for arg in sample_args) <= 5:
1248                # test inputs <= 5 tensors to avoid combinatorial explosion
1249                strided_args = get_strided_args(sample_args)
1250            else:
1251                strided_args = [sample_args]
1252
1253            for args in strided_args:
1254                with MetaCrossRefDispatchMode.push(
1255                    self, dtype=dtype, device=device,
1256                    symbolic_meta=symbolic_meta, inplace=inplace,
1257                     supports_out=op.supports_out):
1258                    expected = func(*args, **kwargs)
1259
1260                    if not inplace and isinstance(expected, torch.Tensor) and op.supports_out:
1261                        func(*args, **kwargs, out=expected)
1262
1263
1264    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1265    @skipIfCrossRef
1266    @suppress_warnings
1267    @ops(itertools.chain(op_db, foreach_op_db))
1268    def test_dispatch_meta_outplace(self, device, dtype, op):
1269        self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=False, inplace=False)
1270
1271    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1272    @skipIfCrossRef
1273    @suppress_warnings
1274    @ops(itertools.chain(op_db, foreach_op_db))
1275    def test_dispatch_meta_inplace(self, device, dtype, op):
1276        self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=False, inplace=True)
1277
1278    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1279    @skipIfCrossRef
1280    @suppress_warnings
1281    @ops(itertools.chain(op_db, foreach_op_db))
1282    def test_dispatch_symbolic_meta_outplace(self, device, dtype, op):
1283        self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=True, inplace=False)
1284
1285
1286    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1287    @skipIfCrossRef
1288    @suppress_warnings
1289    @ops(itertools.chain(op_db, foreach_op_db))
1290    def test_dispatch_symbolic_meta_inplace(self, device, dtype, op):
1291        self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=True, inplace=True)
1292
1293    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1294    @skipIfCrossRef
1295    @suppress_warnings
1296    # only test one dtype, as output stride behavior is the same for all dtypes
1297    @ops(itertools.chain(op_db, foreach_op_db), dtypes=OpDTypes.any_common_cpu_cuda_one)
1298    # Only test on CUDA, as CUDA kernel's stride is the reference
1299    @onlyCUDA
1300    def test_dispatch_symbolic_meta_outplace_all_strides(self, device, dtype, op):
1301        self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=True, inplace=False, all_stride_variants=True)
1302
1303    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1304    @skipIfCrossRef
1305    @suppress_warnings
1306    # only test one dtype, as output stride behavior is the same for all dtypes
1307    @ops(itertools.chain(op_db, foreach_op_db), dtypes=OpDTypes.any_common_cpu_cuda_one)
1308    # Only test on CUDA, as CUDA kernel's stride is the reference
1309    @onlyCUDA
1310    def test_dispatch_symbolic_meta_inplace_all_strides(self, device, dtype, op):
1311        self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=True, inplace=True, all_stride_variants=True)
1312
1313    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1314    @skipIfCrossRef
1315    @suppress_warnings
1316    # only test one dtype, as output stride behavior is the same for all dtypes
1317    @ops(binary_ufuncs, allowed_dtypes=(torch.float32,))
1318    # Only test on CUDA, as CUDA kernel's stride is the reference
1319    @onlyCUDA
1320    def test_binary_ufuncs_mixed_dtype(self, device, dtype, op):
1321        make_arg = partial(
1322            make_tensor,
1323            device=device,
1324        )
1325
1326        def sample_input(op, device, dtype, requires_grad, **kwargs):
1327            yield SampleInput(
1328                make_arg((S,), dtype=dtype), make_arg((S,), dtype=torch.float16)
1329            )
1330
1331        op = copy.copy(op)
1332        op.sample_inputs_func = sample_input
1333
1334        self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=True, inplace=False)
1335
1336
1337    def test_empty_quantized(self):
1338        r = torch.empty(2 ** 52, device='meta', dtype=torch.qint8)
1339        self.assertEqual(r.device.type, 'meta')
1340
1341    def test_nan_to_num(self):
1342        t = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14], device='meta')
1343        r = t.nan_to_num()
1344        self.assertEqual(r.device.type, 'meta')
1345
1346    def test_inplace_masked_fill_error(self):
1347        t = torch.randn(3, 3, device='meta')
1348        with self.assertRaisesRegex(RuntimeError, "doesn't match the broadcast"):
1349            t.masked_fill_((t > 0).unsqueeze(0), 0.1)
1350
1351    def test_inplace_bin_ops_error(self):
1352        t = torch.randn(3, 3, device='meta')
1353        for op in (torch.Tensor.add_, torch.Tensor.sub_, torch.Tensor.mul_, torch.Tensor.div_,
1354                   torch.Tensor.logical_and_, torch.Tensor.logical_or_, torch.Tensor.logical_xor_):
1355            with self.assertRaisesRegex(RuntimeError, "doesn't match the broadcast"):
1356                op(t, t.clone().unsqueeze(0))
1357
1358    @onlyCPU
1359    def test_meta_autograd_no_error(self):
1360        with torch.library._scoped_library("meta_test", "DEF") as lib:
1361            with torch.library._scoped_library("meta_test", "IMPL", "CPU") as impl_cpu:
1362                with torch.library._scoped_library("meta_test", "IMPL", "Meta") as impl_meta:
1363                    def foo_impl(x):
1364                        return x + 1
1365
1366                    lib.define("foo(Tensor a) -> Tensor")
1367                    impl_meta.impl("foo", foo_impl)
1368                    impl_cpu.impl("foo", foo_impl)
1369
1370                    a = torch.ones(2, device='meta')
1371                    # The point of the test is that this should not error:
1372                    # We have a fallthrough kernel registered to the AutogradMeta
1373                    # key for custom ops, so it's fine that `foo()` doesn't have
1374                    # an autograd kernel.
1375                    b = torch.ops.meta_test.foo.default(a)
1376
1377    def test_huber_loss_backward(self):
1378        inps = [torch.rand(2**52, device='meta') for _ in range(3)]
1379        r = torch.ops.aten.huber_loss_backward(*inps, 0, 1.0)
1380        self.assertEqual(r.device.type, 'meta')
1381        self.assertEqual(r.shape, inps[0].shape)
1382
1383    def _norm_backwards_test_helper(self, op, args, output_mask, expected_shapes):
1384
1385        dtype = torch.float32
1386        device = "meta"
1387
1388        # test functional call
1389        grads = op(*args, output_mask)
1390
1391        def assertEqualShapes(res, exp):
1392            self.assertIsNone(res) if exp is None else self.assertEqual(exp, res.shape)
1393
1394        assertEqualShapes(grads[0], expected_shapes[0])
1395        assertEqualShapes(grads[1], expected_shapes[1])
1396        assertEqualShapes(grads[2], expected_shapes[2])
1397
1398        out_kwargs = {
1399            f"out{i}": torch.empty(0, device=device, dtype=dtype)
1400            for i in range(len(output_mask))
1401        }
1402
1403        # test call with out parameters
1404        grads = op(*args, output_mask, **out_kwargs)
1405
1406        def assertEqualShapes(res, exp):
1407            self.assertEqual(exp, res.shape) if exp is not None else True
1408
1409        assertEqualShapes(out_kwargs["out0"], expected_shapes[0])
1410        assertEqualShapes(out_kwargs["out1"], expected_shapes[1])
1411        assertEqualShapes(out_kwargs["out2"], expected_shapes[2])
1412
1413    @onlyCPU
1414    @parametrize("output_mask", list(itertools.product([True, False], [True, False], [True, False])))
1415    def test_layer_norm_backward(self, output_mask):
1416        from torch.testing._internal.common_methods_invocations import sample_inputs_layer_norm
1417
1418        device = "meta"
1419        dtype = torch.float32
1420
1421        samples = sample_inputs_layer_norm(None, device, dtype, requires_grad=False)
1422
1423        for sample in samples:
1424            with self.subTest(sample=sample):
1425                # handle optional weight and bias
1426                if len(sample.args) != 3:
1427                    sample.args = (*sample.args, *([None] * (3 - len(sample.args))))
1428
1429                grad_out = torch.ones_like(sample.input)
1430                normalized_shape, weight, bias = sample.args
1431                ndims_after_reduction = sample.input.ndim - len(normalized_shape)
1432                mean_shape = grad_out.shape[:ndims_after_reduction]
1433                mean = torch.zeros(mean_shape, device=device, dtype=dtype)
1434                rstd = torch.zeros(mean_shape, device=device, dtype=dtype)
1435
1436                expected_shapes = (
1437                    sample.input.shape if output_mask[0] else None,
1438                    weight.shape if output_mask[1] and weight is not None else None,
1439                    bias.shape if output_mask[2] and bias is not None else None)
1440
1441                args = [grad_out, sample.input, normalized_shape, mean, rstd, weight, bias]
1442
1443                self._norm_backwards_test_helper(torch.ops.aten.native_layer_norm_backward,
1444                                                 args, output_mask, expected_shapes)
1445
1446    @onlyCPU
1447    @parametrize("output_mask", list(itertools.product([True, False], [True, False], [True, False])))
1448    def test_group_norm_backward(self, output_mask):
1449        from torch.testing._internal.common_methods_invocations import sample_inputs_group_norm
1450
1451        # input, (args) num_groups, (kwargs) weight, bias eps
1452        device = "meta"
1453        dtype = torch.float32
1454        samples = sample_inputs_group_norm(None, device, dtype, requires_grad=False)
1455
1456        for sample in samples:
1457            with self.subTest(sample=sample):
1458                grad_out = torch.ones_like(sample.input)
1459                N, C = sample.input.shape[:2]
1460                HxW = torch.prod(torch.as_tensor(sample.input.shape[2:]), dtype=torch.int32).item()
1461                group = sample.args[0]
1462                mean = torch.zeros((N, group), device=device, dtype=dtype)
1463                rstd = torch.zeros((N, group), device=device, dtype=dtype)
1464                weight = torch.zeros((C), device=device, dtype=dtype)
1465
1466                args = [grad_out, sample.input, mean, rstd, weight, N, C, HxW, group]
1467
1468                expected_shapes = (
1469                    sample.input.shape if output_mask[0] else None,
1470                    weight.shape if output_mask[1] else None,
1471                    weight.shape if output_mask[2] else None)
1472
1473                # test functional call
1474                self._norm_backwards_test_helper(torch.ops.aten.native_group_norm_backward,
1475                                                 args, output_mask, expected_shapes)
1476
1477    @onlyCPU
1478    @parametrize("output_mask", list(itertools.product([True], [True, False], [True, False])))
1479    def test_batch_norm_backward(self, output_mask):
1480        from torch.testing._internal.common_methods_invocations import sample_inputs_batch_norm
1481
1482        # input, (args) num_groups, (kwargs) weight, bias eps
1483        device = "meta"
1484        dtype = torch.float32
1485        samples = sample_inputs_batch_norm(None, device, dtype, requires_grad=False)
1486
1487        for sample in samples:
1488            with self.subTest(sample=sample):
1489
1490                if sample.input.dim() < 2:
1491                    continue
1492
1493                grad_out = torch.ones_like(sample.input)
1494                running_mean, running_var, weight, bias = sample.args
1495                train = sample.kwargs.get("training", True)
1496                save_mean = torch.zeros((sample.input.shape[1], ), device=device, dtype=dtype) if train else None
1497                save_invstd = torch.zeros((sample.input.shape[1], ), device=device, dtype=dtype) if train else None
1498
1499                args = [grad_out, sample.input, weight, running_mean, running_var,
1500                        save_mean, save_invstd, train, sample.kwargs.get("eps", 1e-5)]
1501
1502                expected_shapes = (
1503                    sample.input.shape,
1504                    torch.Size([sample.input.shape[1]]) if output_mask[1] else None,
1505                    torch.Size([sample.input.shape[1]]) if output_mask[2] else None)
1506
1507                self._norm_backwards_test_helper(torch.ops.aten.native_batch_norm_backward,
1508                                                 args, output_mask, expected_shapes)
1509
1510    def test_fill__alias_relationship(self):
1511        inps = torch.rand(2**52, device='meta')
1512        r = torch.ops.aten.fill_(inps, 1.0)
1513        # aten.fill_ returns an aliase
1514        self.assertEqual(id(inps), id(r))
1515
1516        # aten.fill returns a new tensor
1517        r2 = torch.ops.aten.fill(inps, 1.0)
1518        self.assertNotEqual(id(inps), id(r2))
1519
1520    def test_meta__fused_moving_avg_obs_fq_helper(self, device):
1521        from torch.ao.quantization import FusedMovingAvgObsFakeQuantize
1522        to_meta = MetaConverter()
1523
1524        x = torch.randn(5, 5, device=device)
1525        running_min_op = torch.tensor(float("inf"), device=device)
1526        running_max_op = torch.tensor(float("-inf"), device=device)
1527        avg_const = 0.01
1528        scale = torch.tensor([1.0], device=device)
1529        zero_point = torch.tensor([0], dtype=torch.int, device=device)
1530
1531        mod = FusedMovingAvgObsFakeQuantize()
1532        torch.ao.quantization.enable_fake_quant(mod)
1533        torch.ao.quantization.enable_observer(mod)
1534        mod.to(device)
1535
1536        meta_x = to_meta(x)
1537
1538        args = [
1539            x,
1540            mod.observer_enabled,
1541            mod.fake_quant_enabled,
1542            running_min_op,
1543            running_max_op,
1544            scale,
1545            zero_point,
1546            avg_const,
1547            0,
1548            255,
1549            0,
1550        ]
1551
1552        meta_args = args.copy()
1553        meta_args[0] = meta_x
1554
1555        kwargss = [
1556            {},
1557            {"per_row_fake_quant": False, "symmetric_quant": False},
1558            {"per_row_fake_quant": False, "symmetric_quant": True},
1559        ]
1560
1561        for kwargs in kwargss:
1562            ref_out = aten._fused_moving_avg_obs_fq_helper.default(*args, **kwargs)
1563            meta_out = aten._fused_moving_avg_obs_fq_helper.default(*meta_args, **kwargs)
1564
1565            self.assertEqual(ref_out[0].size(), meta_out[0].size())
1566            self.assertEqual(ref_out[0].stride(), meta_out[0].stride())
1567            self.assertEqual(ref_out[1].size(), meta_out[1].size())
1568            self.assertEqual(ref_out[1].stride(), meta_out[1].stride())
1569
1570    def test_cdist_forward(self, device):
1571        to_meta = MetaConverter()
1572        x1 = torch.rand([3, 2], device=device)
1573        x2 = torch.rand([2, 2], device=device)
1574        p = 2.0
1575        for compute_mode in (None, 1, 2):
1576            ref = aten._cdist_forward.default(x1, x2, p, compute_mode)
1577            res = aten._cdist_forward.default(to_meta(x1), to_meta(x2), p, compute_mode)
1578            self.assertEqual(res.device.type, 'meta')
1579            self.assertEqual(ref.shape, res.shape)
1580
1581    def test_quantized_embedding_bag(self):
1582        tab_shape = [8, 128]
1583        emb_size, ind_len, off_len = tab_shape[0], 32, 33
1584        f_table = torch.from_numpy((np.random.random_sample(tab_shape) + 1).astype(np.float32))
1585        q_table = torch.ops.quantized.embedding_bag_byte_prepack(f_table)
1586        indices = torch.from_numpy(np.random.randint(low=0, high=emb_size, size=ind_len)).int()
1587        max_length = len(indices) // (off_len - 1)
1588        if max_length > 20:
1589            max_length = 20
1590        np_lengths = np.random.randint(0, max_length + 1, size=off_len - 1).astype(np.int32)
1591        offsets = torch.cat([torch.zeros([1]), torch.cumsum(torch.from_numpy(np_lengths), 0)]).int()
1592
1593        eb = torch.ops.quantized.embedding_bag_byte_rowwise_offsets(
1594            q_table.to(device="meta"),
1595            indices.to(device="meta"),
1596            offsets.to(device="meta"),
1597            mode=0,  # sum
1598            per_sample_weights=None,
1599            include_last_offset=True,
1600        )
1601        self.assertEqual(eb.shape, [32, 128])
1602        self.assertEqual(eb.dtype, torch.float32)
1603        self.assertEqual(eb.untyped_storage().data_ptr(), 0)
1604
1605    # Tests mean and max.
1606    # Can't easily test sum, because there is a fast path for sum which
1607    # causes offset2bag to not get allocated... but the backward function
1608    # needs it, and the offset2bag computation lives inside the
1609    # derivatives.yaml formula directly, so there is no way to access it.
1610    # To test sum, need to manually compute offset2bag
1611    @parametrize("mode", [1, 2])
1612    def test_embedding_bag_dense_backward(self, mode):
1613        weight = torch.randn(4, 3, requires_grad=True)
1614        indices = torch.tensor([1, 0, 2, 1, 3])
1615        offsets = torch.tensor([0, 2, 3, 5])
1616        scale_grad_by_freq = False
1617        sparse = False
1618        per_sample_weights = None
1619        include_last_offset = False
1620        padding_idx = -1
1621
1622        output, offset2bag, bag_size, maximum_indices = torch.ops.aten._embedding_bag.default(
1623            weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx
1624        )
1625        grad = torch.randn_like(output)
1626
1627        # Call the function with example inputs
1628        grad_weight = torch.ops.aten._embedding_bag_dense_backward.default(
1629            grad, indices, offset2bag, bag_size, maximum_indices, weight.size(0),
1630            scale_grad_by_freq, mode, per_sample_weights, padding_idx
1631        )
1632        meta_grad_weight = torch.ops.aten._embedding_bag_dense_backward.default(
1633            grad.to('meta'), indices.to('meta'), offset2bag.to('meta'), bag_size.to('meta'),
1634            maximum_indices.to('meta'), weight.size(0),
1635            scale_grad_by_freq, mode, per_sample_weights, padding_idx
1636        )
1637        self.assertEqual(grad_weight.to('meta'), meta_grad_weight)
1638
1639    def test_embedding_bag_dense_backward_per_sample_weights(self):
1640        weight = torch.randn(4, 3, requires_grad=True)
1641        indices = torch.tensor([1, 0, 2, 1, 3])
1642        offsets = torch.tensor([0, 2, 3, 5])
1643        scale_grad_by_freq = False
1644        sparse = False
1645        mode = 0
1646        per_sample_weights = torch.randn(5, requires_grad=True)
1647        include_last_offset = False
1648        padding_idx = -1
1649
1650        output, offset2bag, bag_size, maximum_indices = torch.ops.aten._embedding_bag.default(
1651            weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx
1652        )
1653        grad = torch.randn_like(output)
1654
1655        # Call the function with example inputs
1656        grad_weight = torch.ops.aten._embedding_bag_per_sample_weights_backward.default(
1657            grad, weight, indices, offsets, offset2bag, mode, padding_idx
1658        )
1659        meta_grad_weight = torch.ops.aten._embedding_bag_per_sample_weights_backward.default(
1660            grad.to('meta'), weight.to('meta'), indices.to('meta'),
1661            offsets.to('meta'), offset2bag.to('meta'), mode, padding_idx
1662        )
1663        self.assertEqual(grad_weight.to('meta'), meta_grad_weight)
1664
1665    # opinfo test is using aten.fill_, it's not testing aten.fill
1666    @onlyCUDA
1667    def test_fill_stride(self):
1668        to_meta = MetaConverter()
1669        sample_args = [torch.rand(2, 2, 2, 2), 1.0]
1670
1671        for args in get_strided_args(sample_args):
1672            meta_args = to_meta(args)
1673            ref_out = torch.ops.aten.fill(*args)
1674            meta_out = torch.ops.aten.fill(*meta_args)
1675            self.assertEqual(ref_out.size(), meta_out.size())
1676            self.assertEqual(ref_out.stride(), meta_out.stride())
1677
1678
1679    def test_map_location_deserialize(self):
1680        import io
1681
1682        t = torch.rand(10)
1683        b = io.BytesIO()
1684
1685        torch.save(t, b)
1686        b.seek(0)
1687        r = torch.load(b, map_location=torch.device("meta"))
1688        self.assertEqual(r.device.type, 'meta')
1689        self.assertEqual(r.shape, t.shape)
1690        self.assertEqual(r.dtype, t.dtype)
1691        self.assertEqual(r.storage().data_ptr(), 0)
1692
1693    def test_embedding_bag_byte_prepack(self):
1694        batch_size = 10
1695        num_embeddings = 80
1696        embedding_dim = [128, 256, 512]
1697        res_shape = [[batch_size, num_embeddings, ed + 8] for ed in embedding_dim]
1698        for ed, rs in zip(embedding_dim, res_shape):
1699            weight = torch.randn(batch_size, num_embeddings, ed, dtype=torch.float32)
1700            res = torch.ops.quantized.embedding_bag_byte_prepack(weight.to(device="meta"))
1701            self.assertEqual(res.shape, rs)
1702            self.assertEqual(res.dtype, torch.float32)
1703            self.assertEqual(res.untyped_storage().data_ptr(), 0)
1704
1705    def test_embedding_bag_byte_unpack(self):
1706        batch_size = 10
1707        num_embeddings = 80
1708        embedding_dim = [128, 256, 512]
1709        res_shape = [[batch_size, num_embeddings, ed] for ed in embedding_dim]
1710        for ed, rs in zip(embedding_dim, res_shape):
1711            packed_weight = torch.randn(batch_size, num_embeddings, ed + 8, dtype=torch.float32)
1712            res = torch.ops.quantized.embedding_bag_byte_unpack(packed_weight.to(device="meta"))
1713            self.assertEqual(res.shape, rs)
1714            self.assertEqual(res.dtype, torch.float32)
1715            self.assertEqual(res.untyped_storage().data_ptr(), 0)
1716
1717    def test_index_select_out(self):
1718        def f():
1719            input = torch.randn([8, 16], device='meta')
1720            index = torch.tensor([2, 1, 6, 7, 3, 1, 7, 5, 6, 7], device='meta')
1721            out = torch.empty([10, 16], device='meta')
1722            return torch.index_select(input=input, dim=0, index=index, out=out)
1723        with enable_python_dispatcher():
1724            out = f()
1725            self.assertEqual(out.shape, [10, 16])
1726
1727    def test_local_scalar_dense_call(self):
1728        with self.assertRaisesRegex(RuntimeError, "cannot be called on meta tensors"):
1729            meta_tensor = torch.randn(1, device='meta')
1730            meta_tensor.item()
1731
1732instantiate_device_type_tests(TestMeta, globals())
1733
1734def print_op_str_if_not_supported(op_str):
1735    op = OperatorName.parse(op_str)
1736    packet = getattr(torch.ops.aten, str(op.name))
1737    overload = getattr(packet, op.overload_name if op.overload_name else "default")
1738    if any(overload in d for d in [meta_dispatch_skips, meta_dispatch_device_skips['cuda']]):
1739        print(f"{overload}  # SKIP")
1740    if any(overload in d for d in [meta_dispatch_expected_failures, meta_dispatch_device_expected_failures['cuda']]):
1741        print(overload)
1742
1743
1744if __name__ == "__main__":
1745    COMPARE_XLA = os.getenv('PYTORCH_COMPARE_XLA', None)
1746    if COMPARE_XLA is not None:
1747        with open(COMPARE_XLA) as f:
1748            d = yaml.load(f, Loader=YamlLoader)
1749            ops = d.get("full_codegen", []) + d.get("supported", []) + d.get("autograd", [])
1750            for op_str in ops:
1751                print_op_str_if_not_supported(op_str)
1752        sys.exit(0)
1753
1754    COMPARE_TEXT = os.getenv('PYTORCH_COMPARE_TEXT', None)
1755    if COMPARE_TEXT is not None:
1756        with open(COMPARE_TEXT) as f:
1757            for op_str in f:
1758                print_op_str_if_not_supported(op_str.strip())
1759        sys.exit(0)
1760
1761    run_tests()
1762