# Owner(s): ["module: mta"]

import itertools
import os
import random
import re
import unittest
import weakref
from contextlib import nullcontext
from numbers import Number

import torch
from torch.testing import make_tensor
from torch.testing._comparison import default_tolerances
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
    onlyCUDA,
    OpDTypes,
    ops,
)
from torch.testing._internal.common_dtype import (
    all_types_and_complex_and,
    floating_types,
    floating_types_and,
    integral_types_and,
)
from torch.testing._internal.common_methods_invocations import (
    foreach_binary_op_db,
    foreach_other_op_db,
    foreach_pointwise_op_db,
    foreach_reduce_op_db,
    foreach_unary_op_db,
)
from torch.testing._internal.common_utils import (
    gradcheck,
    parametrize,
    run_tests,
    skipIfRocmVersionLessThan,
    skipIfTorchDynamo,
    TEST_WITH_ROCM,
    TestCase,
)


_BOOL_SUB_ERR_MSG = "Subtraction, the `-` operator"


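# Applies the per-tensor reference function across tensorlists to build the
# expected outputs; e.g. (an illustrative, hypothetical call)
#   RegularFuncWrapper(torch.add)([[a1, a2], [b1, b2]]) -> [a1 + b1, a2 + b2]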
class RegularFuncWrapper:
    def __init__(self, func):
        self.func = func

    def __call__(self, inputs, scalars=None, **kwargs):
        if scalars is not None:
            assert len(inputs) == 3
            # Each scalar has to be distributed to the regular func, and it
            # needs special handling because it is a keyword-only argument to
            # the regular func. (Strangely, it is not keyword-only for the
            # foreach func.)
            return [
                self.func(*i, value=scalars[idx], **kwargs)
                for idx, i in enumerate(zip(*inputs))
            ]
        if len(inputs) == 2 and isinstance(inputs[1], (Number, torch.Tensor)):
            # binary op with tensorlist and scalar.
            inputs[1] = [inputs[1] for _ in range(len(inputs[0]))]
        return [self.func(*i, **kwargs) for i in zip(*inputs)]


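# Wraps a `torch._foreach_*` function. On CUDA, when Kineto profiling is
# available, each call is profiled and we assert that
# `multi_tensor_apply_kernel` appears in the trace iff the fast path is
# expected (and the inputs are not zero size).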
class ForeachFuncWrapper:
    def __init__(self, func):
        self.func = func
        # Some foreach functions don't have in-place implementations.
        self.is_inplace = False if func is None else func.__name__.endswith("_")

    def __call__(self, inputs, is_cuda, expect_fastpath, **kwargs):
        actual = None
        zero_size = kwargs.pop("zero_size", False)
        if (
            is_cuda
            and torch.autograd.kineto_available()
            and torch.profiler.ProfilerActivity.CUDA
            in torch.profiler.supported_activities()
        ):
            with torch.profiler.profile() as p:
                actual = self.func(*inputs, **kwargs)
            keys = tuple([e.key for e in p.key_averages()])
            mta_called = any("multi_tensor_apply_kernel" in k for k in keys)
            assert (
                mta_called == (expect_fastpath and (not zero_size))
            ), f"{mta_called=}, {expect_fastpath=}, {zero_size=}, {self.func.__name__=}, {keys=}"
        else:
            actual = self.func(*inputs, **kwargs)
        if self.is_inplace:
            assert id(inputs[0]) == id(actual)
        return actual


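# Context manager asserting that the wrapped in-place foreach call bumped the
# version counter of every tensor in `tensorlist` at least once, as in-place
# ops are expected to do.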
class InplaceForeachVersionBumpCheck:
    def __init__(
        self,
        testcase: TestCase,
        tensorlist: "List[torch.Tensor]",  # noqa: F821
    ) -> None:
        self._testcase = testcase
        self._tensorlist = tensorlist
        self._orig_version_counts = [t._version for t in tensorlist]

    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_value, traceback):
        # note(crcrpar): some methods e.g. `_binary_test` could call the given inplace function multiple times
        self._testcase.assertGreaterEqual(
            [t._version for t in self._tensorlist], self._orig_version_counts
        )


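# Builds a transform for `sample.transform` that replaces every non-0-dim
# tensor with a fresh (num_tensors, num_tensors) tensor requiring grad,
# noncontiguous when exercising the slow path.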
def get_transform_func(num_tensors, dtype, device, is_fastpath):
    def transform(t):
        if not torch.is_tensor(t):
            return t
        if torch.is_tensor(t) and t.ndim == 0:
            return t
        return make_tensor(
            (num_tensors, num_tensors),
            dtype=dtype,
            device=device,
            requires_grad=True,
            noncontiguous=not is_fastpath,
        )

    return transform


# note(crcrpar): `zero_size` is `False` unless (dtype, device) == (torch.float32, "cuda"),
# as that pair goes through `multi_tensor_apply_kernel` when inputs are not zero size.
@unittest.mock.patch.dict(os.environ, {"KINETO_LOG_LEVEL": "5"})
class TestForeach(TestCase):
    @property
    def is_cuda(self):
        return self.device_type == "cuda"

    def _get_funcs(self, op):
        return (
            ForeachFuncWrapper(op.method_variant),
            RegularFuncWrapper(op.ref),
            ForeachFuncWrapper(op.inplace_variant),
            RegularFuncWrapper(op.ref_inplace),
        )

    # note(crcrpar): Make sure 0-size tensors are appropriately ignored by `multi_tensor_apply`,
    # a problem originally reported in https://github.com/pytorch/pytorch/issues/94865.
    # rel:
    #   - https://github.com/pytorch/pytorch/pull/94655
    #   - https://github.com/pytorch/pytorch/issues/100701
    #   - https://github.com/pytorch/pytorch/pull/100811
    @onlyCUDA
    @ops(
        foreach_unary_op_db
        + foreach_binary_op_db
        + foreach_pointwise_op_db
        + foreach_reduce_op_db
        + foreach_other_op_db,
        dtypes=(torch.float32,),
    )
    def test_all_zero_size_tensors_do_not_launch_kernel(self, device, dtype, op):
        wrapped_op, _, inplace_op, _ = self._get_funcs(op)

        for sample in op.sample_zero_size_inputs(device, dtype):
            if op.method_variant is not None:
                wrapped_op(
                    (sample.input, *sample.args),
                    is_cuda=self.is_cuda,
                    expect_fastpath=True,
                    zero_size=True,
                )

            if op.inplace_variant is not None:
                with InplaceForeachVersionBumpCheck(self, sample.input):
                    inplace_op(
                        (sample.input, *sample.args),
                        is_cuda=self.is_cuda,
                        expect_fastpath=True,
                        zero_size=True,
                    )

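    # Compare each foreach op against a per-tensor for-loop reference over the
    # opinfo samples, covering the fast (contiguous) and slow (noncontiguous)
    # paths, both in-place and out-of-place. When the foreach op raises, the
    # reference must raise an exception of the same type.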
    @skipIfRocmVersionLessThan((6, 0))
    @ops(
        foreach_unary_op_db
        + foreach_binary_op_db
        + foreach_pointwise_op_db
        + foreach_reduce_op_db
        + foreach_other_op_db,
    )
    @parametrize(
        "noncontiguous,inplace",
        [(False, False), (False, True), (True, False), (True, True)],
        name_fn=lambda x, y: "{}_{}".format(
            "fastpath" if not x else "slowpath", "inplace" if y else "outplace"
        ),
    )
    def test_parity(self, device, dtype, op, noncontiguous, inplace):
        if inplace:
            _, _, func, ref = self._get_funcs(op)
        else:
            func, ref, _, _ = self._get_funcs(op)
        for sample in op.sample_inputs(
            device, dtype, noncontiguous=noncontiguous, allow_higher_dtype_scalars=True
        ):
            ref_kwargs = sample.kwargs
            # div promotes ints to floats, so we cannot go on the fastpath there
            div_slowpath = (
                dtype in integral_types_and(torch.bool) and op.name == "_foreach_div"
            )
            expect_fastpath = not (
                noncontiguous or sample.disable_fastpath or div_slowpath
            )
            ref_input, ctxmgr = sample.input, nullcontext()
            if inplace:
                with torch.no_grad():
                    ref_input = [t.clone().detach() for t in sample.input]
                ctxmgr = InplaceForeachVersionBumpCheck(self, sample.input)
            try:
                with ctxmgr:
                    actual = func(
                        [sample.input, *sample.args],
                        self.is_cuda,
                        expect_fastpath,
                        **sample.kwargs,
                    )
            except Exception as e:
                with self.assertRaises(type(e)):
                    ref([ref_input, *sample.ref_args], **ref_kwargs)
            else:
                expected = ref([ref_input, *sample.ref_args], **ref_kwargs)
                self.assertEqual(expected, actual)

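    # Shared helper for binary ops: runs `op` on `inputs` and checks parity
    # with `ref` (or that `ref` raises the same exception type); when `alpha`
    # is given, the check is repeated with the `alpha` keyword argument.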
    def _binary_test(
        self,
        dtype,
        op,
        ref,
        inputs,
        is_fastpath,
        is_inplace,
        *,
        alpha,
        scalar_self_arg: bool,
    ):
        ref_inputs = (
            [[t.clone().detach() for t in inputs[0]], inputs[1]]
            if is_inplace
            else inputs
        )
        try:
            with InplaceForeachVersionBumpCheck(
                self, inputs[0]
            ) if op.is_inplace else nullcontext():
                actual = op(inputs, self.is_cuda, is_fastpath)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])):
                if not scalar_self_arg:
                    ref(ref_inputs)
                else:
                    [ref.func(ref_inputs[0], t) for t in ref_inputs[1]]
        else:
            expected = (
                ref(ref_inputs)
                if not scalar_self_arg
                else [ref.func(ref_inputs[0], t) for t in ref_inputs[1]]
            )
            self.assertEqual(actual, expected)
        if alpha is not None and not scalar_self_arg:
            kwargs = {"alpha": alpha}
            ref_inputs = inputs
            try:
                op_kwargs = {}
                op_kwargs.update(kwargs)
                with InplaceForeachVersionBumpCheck(
                    self, inputs[0]
                ) if op.is_inplace else nullcontext():
                    actual = op(inputs, self.is_cuda, is_fastpath, **op_kwargs)
            except RuntimeError as e:
                with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])):
                    ref(ref_inputs, **kwargs)
            else:
                expected = ref(ref_inputs, **kwargs)
                if dtype in (torch.float16, torch.bfloat16) and TEST_WITH_ROCM:
                    self.assertEqual(
                        expected, actual, atol=1.0e-3, rtol=default_tolerances(dtype)[0]
                    )
                else:
                    self.assertEqual(expected, actual)

    @ops(filter(lambda op: op.supports_scalar_self_arg, foreach_binary_op_db))
    @parametrize("is_fastpath", (True, False))
    def test_binary_op_with_scalar_self_support(self, device, dtype, op, is_fastpath):
        def clone(arg):
            if isinstance(arg, (list, tuple)):
                return [clone(a) for a in arg]
            if torch.is_tensor(arg):
                return arg.clone().detach().requires_grad_()
            else:
                return arg

        scalar_self_arg_test_complete = False
        for i, sample in enumerate(
            op.sample_inputs(
                device,
                dtype,
                noncontiguous=not is_fastpath,
                allow_higher_dtype_scalars=True,
            )
        ):
            (rhs_arg,) = sample.args
            kwargs = sample.kwargs or {}
            alpha = kwargs.pop("alpha", None)
            wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op)
            if isinstance(rhs_arg, Number) and not scalar_self_arg_test_complete:
                scalar_self_arg_test_complete = True
                self._binary_test(
                    dtype,
                    wrapped_op,
                    ref,
                    [rhs_arg, sample.input],
                    is_fastpath,
                    False,
                    alpha=alpha,
                    scalar_self_arg=True,
                )
                if op.supports_autograd and dtype == torch.float32:
                    transformed_sample = sample.transform(
                        get_transform_func(
                            len(sample.input), dtype, device, is_fastpath
                        )
                    )
                    tensors = transformed_sample.input
                    (rhs_arg,) = transformed_sample.args
                    ref_tensors, ref_rhs_arg = clone(tensors), clone(rhs_arg)
                    sum(
                        wrapped_op(
                            [rhs_arg, tensors], is_cuda=False, expect_fastpath=False
                        )
                    ).mean().backward()
                    sum(ref.func(ref_rhs_arg, t) for t in ref_tensors).mean().backward()
                    self.assertEqual(
                        [t.grad for t in tensors], [t.grad for t in ref_tensors]
                    )

    @ops(foreach_pointwise_op_db)
    @parametrize("is_fastpath", (True, False))
    def test_pointwise_op_with_tensor_of_scalarlist_overload(
        self, device, dtype, op, is_fastpath
    ):
        for sample in op.sample_inputs(
            device,
            dtype,
            noncontiguous=not is_fastpath,
            allow_higher_dtype_scalars=True,
        ):
            assert isinstance(sample.args, tuple)
            assert len(sample.args) == 2
            inputs = [sample.input, *sample.args]
            kwargs = sample.kwargs.copy()
            disable_fastpath = sample.disable_fastpath and is_fastpath
            wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op)
            scalars = kwargs.pop("scalars", None)

            if is_fastpath and scalars:
                sample = sample.transform(
                    lambda t: t.clone().detach() if torch.is_tensor(t) else t
                )
                inputs = [sample.input, *sample.args]
                tensor_values = torch.tensor(scalars)
                # 1D Tensor of scalars
                for is_inplace, op_, ref_ in (
                    (False, wrapped_op, ref),
                    (True, inplace_op, inplace_ref),
                ):
                    self._pointwise_test(
                        op_,
                        ref_,
                        inputs,
                        is_fastpath and not disable_fastpath,
                        is_inplace,
                        scalars=tensor_values,
                        **kwargs,
                    )
                    self._pointwise_test(
                        op_,
                        ref_,
                        inputs,
                        is_fastpath and not disable_fastpath,
                        is_inplace,
                        scalars=tensor_values[0],
                        custom_values_err="Expected packed scalar Tensor to be of dimension 1. Got 0 instead.",
                        **kwargs,
                    )
                    if self.is_cuda:
                        self._pointwise_test(
                            op_,
                            ref_,
                            inputs,
                            is_fastpath and not disable_fastpath,
                            is_inplace,
                            scalars=tensor_values.cuda(),
                            custom_values_err="Expected scalars to be on CPU, got cuda:0 instead.",
                            **kwargs,
                        )
                    self._pointwise_test(
                        op_,
                        ref_,
                        inputs,
                        is_fastpath and not disable_fastpath,
                        is_inplace,
                        scalars=tensor_values[:2],
                        custom_values_err=f"Expected length of scalars to match input of length {len(scalars)} but got 2 instead.",
                        **kwargs,
                    )
                    self._pointwise_test(
                        op_,
                        ref_,
                        inputs,
                        is_fastpath and not disable_fastpath,
                        is_inplace,
                        scalars=torch.tensor([[0, 1], [2, 3]])[:, 1],
                        custom_values_err="Expected scalars to be contiguous.",
                        **kwargs,
                    )

            # Tests of implicit broadcasting
            N = len(sample.input)
            inputs = [
                [
                    make_tensor(
                        (N, N),
                        device=device,
                        dtype=dtype,
                        noncontiguous=not is_fastpath,
                    )
                    for _ in range(N)
                ],
                [
                    make_tensor(
                        (N - i, 1),
                        device=device,
                        dtype=dtype,
                        noncontiguous=not is_fastpath,
                    )
                    for i in range(N)
                ],
                [
                    make_tensor(
                        (1, N - i),
                        device=device,
                        dtype=dtype,
                        noncontiguous=not is_fastpath,
                    )
                    for i in range(N)
                ],
            ]
            self._pointwise_test(
                wrapped_op,
                ref,
                inputs,
                is_fastpath and disable_fastpath,
                is_inplace=False,
                scalars=scalars,
                **kwargs,
            )
            self._pointwise_test(
                inplace_op,
                inplace_ref,
                inputs,
                is_fastpath and disable_fastpath,
                is_inplace=True,
                scalars=scalars,
                **kwargs,
            )

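    # Shared helper for pointwise ops (addcmul/addcdiv-style): checks parity
    # with `ref`, then, if `scalars` is given, re-runs with the `scalars`
    # keyword and matches either the reference's error or `custom_values_err`.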
    def _pointwise_test(
        self,
        op,
        ref,
        inputs,
        is_fastpath,
        is_inplace,
        *,
        scalars=None,
        custom_values_err=None,
        **kwargs,
    ):
        ref_inputs = (
            [[t.clone().detach() for t in inputs[0]], inputs[1], inputs[2]]
            if is_inplace
            else inputs
        )
        try:
            with (
                InplaceForeachVersionBumpCheck(self, inputs[0])
                if is_inplace
                else nullcontext()
            ):
                actual = op(inputs, self.is_cuda, is_fastpath, **kwargs)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])):
                ref(ref_inputs, **kwargs)
        else:
            expected = ref(ref_inputs, **kwargs)
            self.assertEqual(expected, actual)
        if scalars is not None:
            kwargs = kwargs.copy()
            kwargs["scalars"] = scalars
            try:
                actual = op(inputs, self.is_cuda, is_fastpath, **kwargs)
            except RuntimeError as e:
                # Match with error messages from regular non-foreach reference if no
                # custom error message was provided.
                if custom_values_err is None:
                    with self.assertRaisesRegex(
                        type(e), re.escape(str(e).splitlines()[0])
                    ):
                        ref(ref_inputs, **kwargs)
                else:
                    self.assertEqual(re.escape(str(e)), re.escape(custom_values_err))
            else:
                expected = ref(ref_inputs, **kwargs)
                self.assertEqual(expected, actual)

    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
    def test_add_scalar_with_empty_list_and_empty_tensor(self, device, dtype):
        # TODO: enable empty list case
        for tensors in [
            [torch.randn([0], device=device, dtype=dtype)],
            [torch.empty_strided((0, 1), (0, 0), dtype=dtype, device=device)],
        ]:
            res = torch._foreach_add(tensors, 1)
            self.assertEqual(res, tensors)

            torch._foreach_add_(tensors, 1)
            self.assertEqual(res, tensors)

            # Regression test for https://github.com/pytorch/pytorch/issues/113156
            torch._foreach_mul_(tensors, 1)

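    # The fast path requires matching strides, but strides of size-1
    # dimensions should not matter: (9, 9, 3, 1) and (9, 1, 3, 1) below differ
    # only in a size-1 dim, so the fused kernel is still expected to launch.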
    @onlyCUDA
    @dtypes(torch.float32)
    def test_foreach_check_stride_ignore_dims_of_one(self, device, dtype):
        # default tensor stride is (9, 9, 3, 1).
        tensor = torch.ones((2, 1, 3, 3), device=device, dtype=dtype)
        strided_tensor = torch.ones(
            (2, 1, 3, 3), device=device, dtype=dtype
        ).as_strided((2, 1, 3, 3), (9, 1, 3, 1))
        left_inputs = [tensor, strided_tensor]
        right_inputs = [strided_tensor, tensor]
        compare_result = tensor + strided_tensor
        foreach_add_check_ = ForeachFuncWrapper(torch._foreach_add)
        out = foreach_add_check_(
            (left_inputs, right_inputs), is_cuda=True, expect_fastpath=True
        )
        for res in out:
            self.assertEqual(res, compare_result)

    @ops(
        filter(lambda op: op.supports_out, foreach_binary_op_db),
        dtypes=OpDTypes.supported,
    )
    def test_binary_op_scalar_with_overlapping_tensors(self, device, dtype, op):
        foreach_op, ref = op.method_variant, op.ref
        tensors = [torch.ones(1, 1, device=device, dtype=dtype).expand(2, 1, 3)]

        if ref == torch.sub and dtype == torch.bool:
            with self.assertRaisesRegex(RuntimeError, re.escape(_BOOL_SUB_ERR_MSG)):
                [ref(t, 1) for t in tensors]
            with self.assertRaisesRegex(RuntimeError, re.escape(_BOOL_SUB_ERR_MSG)):
                foreach_op(tensors, 1)
            return

        expected = [ref(t, 1) for t in tensors]
        res = foreach_op(tensors, 1)
        self.assertEqual(res, expected)

    @ops(
        filter(lambda op: op.supports_out, foreach_binary_op_db),
        allowed_dtypes=[torch.float],
    )
    def test_binary_op_scalar_with_different_tensor_dtypes(self, device, dtype, op):
        foreach_op = op.method_variant
        tensors = [
            torch.tensor([1.1], dtype=torch.float, device=device),
            torch.tensor([1], dtype=torch.long, device=device),
        ]
        runtime_error = None
        try:
            foreach_op(tensors, 1)
        except RuntimeError as e:
            runtime_error = e
        self.assertIsNone(runtime_error)

    @skipIfTorchDynamo("Different error msgs, TODO")
    @ops(
        filter(lambda op: op.supports_out, foreach_binary_op_db),
        dtypes=OpDTypes.supported,
    )
    def test_binary_op_list_error_cases(self, device, dtype, op):
        foreach_op, foreach_op_, ref, ref_ = (
            op.method_variant,
            op.inplace_variant,
            op.ref,
            op.ref_inplace,
        )
        tensors1 = []
        tensors2 = []
        ops_to_test = [foreach_op, foreach_op_]

        # Empty lists
        for fop in ops_to_test:
            with self.assertRaisesRegex(
                RuntimeError, "Tensor list must have at least one tensor."
            ):
                fop(tensors1, tensors2)

        # One empty list
        tensors1.append(torch.tensor([1], device=device, dtype=dtype))
        for fop in ops_to_test:
            with self.assertRaisesRegex(
                RuntimeError,
                "Tensor list must have same number of elements as scalar list.",
            ):
                fop(tensors1, tensors2)

        # Lists have different numbers of tensors
        tensors2.append(torch.tensor([1], device=device))
        tensors2.append(torch.tensor([1], device=device))
        for fop in ops_to_test:
            with self.assertRaisesRegex(
                RuntimeError,
                "Tensor lists must have the same number of tensors, got 1 and 2",
            ):
                fop(tensors1, tensors2)
            with self.assertRaisesRegex(
                RuntimeError,
                "Tensor lists must have the same number of tensors, got 2 and 1",
            ):
                fop(tensors2, tensors1)

        # Corresponding tensors with different sizes that aren't compatible with broadcast.
        # If sizes are different then foreach chooses the slow path, thus error messages are
        # expected to match those of the regular torch function.
        tensors1 = [torch.zeros(10, 10, device=device, dtype=dtype) for _ in range(10)]
        tensors2 = [torch.ones(11, 11, device=device, dtype=dtype) for _ in range(10)]

        if dtype == torch.bool and foreach_op == torch._foreach_sub:
            for fop in ops_to_test:
                with self.assertRaisesRegex(RuntimeError, re.escape(_BOOL_SUB_ERR_MSG)):
                    fop(tensors1, tensors2)
            return
        with self.assertRaisesRegex(
            RuntimeError,
            r"The size of tensor a \(10\) must match the size of tensor b \(11\) at non-singleton dimension 1",
        ):
            foreach_op(tensors1, tensors2)
        with self.assertRaisesRegex(
            RuntimeError,
            r"The size of tensor a \(10\) must match the size of tensor b \(11\) at non-singleton dimension 1",
        ):
            foreach_op_(tensors1, tensors2)

        # different devices
        if self.device_type == "cuda" and torch.cuda.device_count() > 1:
            tensor1 = torch.zeros(10, 10, device="cuda:0", dtype=dtype)
            tensor2 = torch.ones(10, 10, device="cuda:1", dtype=dtype)
            with self.assertRaisesRegex(
                RuntimeError, "Expected all tensors to be on the same device"
            ):
                foreach_op([tensor1], [tensor2])
            if (
                dtype in integral_types_and(torch.bool)
                and foreach_op == torch._foreach_div
            ):
                with self.assertRaisesRegex(RuntimeError, "result type"):
                    foreach_op_([tensor1], [tensor2])
            else:
                with self.assertRaisesRegex(
                    RuntimeError, "Expected all tensors to be on the same device"
                ):
                    foreach_op_([tensor1], [tensor2])

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not found")
    @ops(
        filter(lambda op: op.supports_out, foreach_binary_op_db),
        dtypes=OpDTypes.supported,
    )
    def test_binary_op_list_slow_path(self, device, dtype, op):
        foreach_op, native_op, foreach_op_, native_op_ = self._get_funcs(op)
        # 0-strides
        tensor1 = make_tensor((10, 10), dtype=dtype, device=device)
        tensor2 = make_tensor((1,), device=device, dtype=dtype).expand_as(tensor1)
        inputs = ([tensor1], [tensor2])
        self._binary_test(
            dtype,
            foreach_op,
            native_op,
            inputs,
            is_fastpath=False,
            is_inplace=False,
            alpha=None,
            scalar_self_arg=False,
        )
        self._binary_test(
            dtype,
            foreach_op_,
            native_op_,
            inputs,
            is_fastpath=False,
            is_inplace=True,
            alpha=None,
            scalar_self_arg=False,
        )

        # different strides
        tensor1 = torch.zeros(10, 10, device=device, dtype=dtype)
        tensor2 = torch.ones(10, 10, device=device, dtype=dtype)
        inputs = ([tensor1], [tensor2.t()])
        self._binary_test(
            dtype,
            foreach_op,
            native_op,
            inputs,
            is_fastpath=False,
            is_inplace=False,
            alpha=None,
            scalar_self_arg=False,
        )
        self._binary_test(
            dtype,
            foreach_op_,
            native_op_,
            inputs,
            is_fastpath=False,
            is_inplace=True,
            alpha=None,
            scalar_self_arg=False,
        )

        # non-contiguous
        tensor1 = make_tensor(
            (5, 2, 1, 3), device=device, dtype=dtype, noncontiguous=True
        )
        tensor2 = make_tensor(
            (5, 2, 1, 3), device=device, dtype=dtype, noncontiguous=True
        )
        self.assertFalse(tensor1.is_contiguous())
        self.assertFalse(tensor2.is_contiguous())
        inputs = ([tensor1], [tensor2])
        self._binary_test(
            dtype,
            foreach_op,
            native_op,
            inputs,
            is_fastpath=False,
            is_inplace=False,
            alpha=None,
            scalar_self_arg=False,
        )
        self._binary_test(
            dtype,
            foreach_op_,
            native_op_,
            inputs,
            is_fastpath=False,
            is_inplace=True,
            alpha=None,
            scalar_self_arg=False,
        )

        # sliced tensor
        tensor1 = make_tensor((5, 2, 1, 3), device=device, dtype=dtype)
        tensor2 = make_tensor((5, 2, 1, 3 * 7), device=device, dtype=dtype)[
            :, :, :, ::7
        ]
        inputs = ([tensor1], [tensor2])
        self._binary_test(
            dtype,
            foreach_op,
            native_op,
            inputs,
            is_fastpath=False,
            is_inplace=False,
            alpha=None,
            scalar_self_arg=False,
        )
        self._binary_test(
            dtype,
            foreach_op_,
            native_op_,
            inputs,
            is_fastpath=False,
            is_inplace=True,
            alpha=None,
            scalar_self_arg=False,
        )

    @ops(
        filter(lambda op: op.supports_out, foreach_binary_op_db),
        dtypes=floating_types_and(torch.half, torch.bfloat16),
    )
    def test_binary_op_float_inf_nan(self, device, dtype, op):
        inputs = (
            [
                torch.tensor([float("inf")], device=device, dtype=dtype),
                torch.tensor([-float("inf")], device=device, dtype=dtype),
                torch.tensor([float("nan")], device=device, dtype=dtype),
                torch.tensor([float("nan")], device=device, dtype=dtype),
            ],
            [
                torch.tensor([-float("inf")], device=device, dtype=dtype),
                torch.tensor([float("inf")], device=device, dtype=dtype),
                torch.tensor([float("inf")], device=device, dtype=dtype),
                torch.tensor([float("nan")], device=device, dtype=dtype),
            ],
        )
        op, ref, inplace_op, inplace_ref = self._get_funcs(op)
        self._binary_test(
            dtype, op, ref, inputs, True, False, alpha=None, scalar_self_arg=False
        )
        self._binary_test(
            dtype,
            inplace_op,
            inplace_ref,
            inputs,
            True,
            True,
            alpha=None,
            scalar_self_arg=False,
        )

    # note: The three tests below (postfixed with `_tensors_on_different_devices`)
    # check whether foreach works with lists of tensors on different devices,
    # where tensors at the same index share a device, e.g., ['cuda', 'cpu'].
    @onlyCUDA
    @ops(foreach_unary_op_db)
    def test_unary_op_tensors_on_different_devices(self, device, dtype, op):
        method, ref, inplace_method, ref_inplace = self._get_funcs(op)
        # tensors: ['cuda', 'cpu']
        tensors = next(
            iter(
                op.sample_inputs(
                    device,
                    dtype,
                    num_input_tensors=[2],
                    allow_higher_dtype_scalars=True,
                )
            )
        ).input
        tensors[1] = tensors[1].to("cpu")
        if not op.supports_out:
            try:
                actual = method((tensors,), False, False, zero_size=False)
            except RuntimeError as e:
                with self.assertRaisesRegex(type(e), str(e).splitlines()[0]):
                    ref((tensors,))
            else:
                expected = ref((tensors,))
                self.assertEqual(expected, actual)

        try:
            inplace_method((tensors,), False, False, zero_size=False)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), str(e).splitlines()[0]):
                ref_inplace((tensors,))
        else:
            if not op.supports_out:
                self.assertEqual(expected, tensors)
            else:
                self.assertEqual([torch.zeros_like(t) for t in tensors], tensors)

    @onlyCUDA
    @ops(filter(lambda op: op.supports_out, foreach_binary_op_db))
    def test_binary_op_tensors_on_different_devices(self, device, dtype, op):
        _cuda_tensors = next(
            iter(
                op.sample_inputs(
                    device,
                    dtype,
                    num_input_tensors=[2],
                    same_size=True,
                    allow_higher_dtype_scalars=True,
                )
            )
        ).input
        _cpu_tensors = next(
            iter(
                op.sample_inputs(
                    "cpu",
                    dtype,
                    num_input_tensors=[2],
                    same_size=True,
                    allow_higher_dtype_scalars=True,
                )
            )
        ).input
        tensors1, tensors2 = list(zip(_cuda_tensors, _cpu_tensors))

        foreach_op, foreach_op_ = op.method_variant, op.inplace_variant
        native_op, native_op_ = op.ref, op.ref_inplace
        try:
            actual = foreach_op(tensors1, tensors2)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])):
                [native_op(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
        else:
            expected = [native_op(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
            self.assertEqual(expected, actual)
        try:
            foreach_op_(tensors1, tensors2)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])):
                [native_op_(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
        else:
            self.assertEqual(actual, tensors1)

    @onlyCUDA
    @ops(foreach_pointwise_op_db, allowed_dtypes=floating_types())
    def test_pointwise_op_tensors_on_different_devices(self, device, dtype, op):
        # tensors1: ['cuda', 'cpu']
        # tensors2: ['cuda', 'cpu']
        # tensors3: ['cuda', 'cpu']
        # first tensorlist is zero-size when float32
        _cuda_tensors = list(
            op.sample_inputs(
                device,
                dtype,
                num_input_tensors=[3],
                same_size=True,
                allow_higher_dtype_scalars=True,
            )
        )[int(dtype == torch.float32)].input
        _cpu_tensors = next(
            iter(
                op.sample_inputs(
                    "cpu",
                    dtype,
                    num_input_tensors=[3],
                    same_size=True,
                    allow_higher_dtype_scalars=True,
                )
            )
        ).input
        tensors1, tensors2, tensors3 = list(zip(_cuda_tensors, _cpu_tensors))

        foreach_op, foreach_op_, native_op = (
            op.method_variant,
            op.inplace_variant,
            op.ref,
        )
        actual = foreach_op(tensors1, tensors2, tensors3)
        expected = [native_op(*_cuda_tensors), native_op(*_cpu_tensors)]
        self.assertEqual(expected, actual)

        # note(mkozuki): Limiting dtypes to FP32&FP64, we can safely run inplace ops.
        foreach_op_(tensors1, tensors2, tensors3)
        self.assertEqual(expected, tensors1)

    # note: BFloat16 has the same number of exponent bits as FP32
    # so if squared L2 norm overflows in BF16, then it also overflows in FP32.
    @onlyCUDA
    @ops(
        [o for o in foreach_reduce_op_db if "norm" in o.name],
        allowed_dtypes=(torch.half, torch.bfloat16),
    )
    def test_foreach_l2_large_value_input(self, device, dtype, op):
        ord, N = 2, 10
        max_value = torch.finfo(dtype).max
        scaler = torch.tensor([max_value]).sqrt().to(device=device, dtype=dtype)
        inputs = (
            [
                t * scaler
                for t in next(
                    iter(
                        op.sample_inputs(
                            device,
                            dtype,
                            requires_grad=True,
                            num_input_tensors=[N],
                            low=1,
                        )
                    )
                ).input
            ][:-1],
        )
        # make sure that the minimum squared L2 norm per tensor is greater than the max value of `dtype`.
        self.assertTrue(scaler * scaler * N > max_value)
        fn, ref_fn, *_ = self._get_funcs(op)
        actual = fn(
            inputs, is_cuda=True, expect_fastpath=True, ord=ord, zero_size=False
        )
        expect = ref_fn(inputs, ord=ord)

        if dtype == torch.float16:
            # making sure the reference L2 norm values are in the range of FP16.
            self.assertFalse(any(torch.isinf(e) for e in expect))
        else:
            self.assertTrue(
                all(
                    inputs[0][i].numel() == 0 or torch.isinf(e)
                    for i, e in enumerate(expect)
                )
            )
        self.assertEqual(expect, actual, equal_nan=False)

    @onlyCUDA
    @ops(foreach_reduce_op_db, allowed_dtypes=floating_types())
    @parametrize("use_cuda_graph", (False, True))
    def test_big_num_tensors(self, device, dtype, op, use_cuda_graph):
        N = 600
        tensorlist = [
            make_tensor((2, 3), dtype=dtype, device=device, noncontiguous=False)
            for _ in range(N)
        ]
        fn, ref_fn, *_ = self._get_funcs(op)

        import math

        if op.name == "_foreach_norm":
            ords = (1, 2, math.inf)
        else:
            ords = (None,)

        for ord in ords:
            kwargs = {"ord": ord} if ord else {}
            if not use_cuda_graph:
                actual = fn(
                    inputs=[tensorlist],
                    is_cuda=True,
                    expect_fastpath=True,
                    zero_size=False,
                    **kwargs,
                )
            else:
                # When using CUDA graphs and the tensor metadata doesn't fit in
                # the static kernel argument space, multi_tensor_apply creates
                # the launch arguments once, uses cudaUserObject_t to tie its
                # lifetime to the graph, and reuses it throughout replays. This
                # test verifies multi_tensor_apply's behavior in that scenario.
                g = torch.cuda.CUDAGraph()
                with torch.cuda.graph(g):
                    actual = fn.func(tensorlist, **kwargs)
                g.replay()
            expect = ref_fn(inputs=[tensorlist], **kwargs)

            self.assertEqual(expect, actual, equal_nan=True)

    @onlyCUDA
    @ops(foreach_reduce_op_db)
    def test_foreach_reduce_large_input(self, device, dtype, op):
        # test inputs larger than kChunkSize = 65536
        N = 65536 * 2
        disable_fastpath = False
        kwargs = {}
        if op.name == "_foreach_norm":
            ord = 2
            disable_fastpath = not (
                ord in (1, 2)
                and dtype in floating_types_and(torch.half, torch.bfloat16)
            )
            kwargs["ord"] = ord

        inputs = ([make_tensor((N,), dtype=dtype, device=device, noncontiguous=False)],)
        wrapped_op, ref, _, _ = self._get_funcs(op)
        self.assertEqual(
            ref(inputs, **kwargs),
            wrapped_op(
                inputs, self.is_cuda, not disable_fastpath, zero_size=False, **kwargs
            ),
        )

    @onlyCUDA
    @ops(
        foreach_unary_op_db
        + foreach_binary_op_db
        + foreach_pointwise_op_db
        + foreach_other_op_db,
        dtypes=(torch.float,),
    )
    def test_inplace_foreach_leaf_check_and_grad_fn(self, device, dtype, op):
        inplace_op = op.inplace_variant
        if inplace_op is None:
            self.skipTest("no in-place op available")

        sample = next(
            iter(
                op.sample_inputs(
                    dtype=dtype, device=device, num_input_tensors=[2], same_size=True
                )
            )
        )
        sample.input[0].requires_grad_(True)
        with self.assertRaisesRegex(RuntimeError, "a leaf Variable that requires grad"):
            inplace_op(sample.input, *sample.args)
        sample.input[1].requires_grad_(True)
        with self.assertRaisesRegex(RuntimeError, "a leaf Variable that requires grad"):
            inplace_op(sample.input, *sample.args)

        _tensors = [
            t.clone().detach().requires_grad_(i == 0)
            for i, t in enumerate(sample.input)
        ]
        tensors = [t.clone() for t in _tensors]
        inplace_op(tensors, *sample.args)
        self.assertIsNotNone(tensors[0].grad_fn)
        self.assertIsNone(tensors[1].grad_fn)

    @onlyCUDA
    @ops(
        filter(
            lambda op: op.supports_out,
            foreach_unary_op_db
            + foreach_binary_op_db
            + foreach_pointwise_op_db
            + foreach_other_op_db,
        ),
        dtypes=(torch.float,),
    )
    def test_outplace_with_invalid_grads(self, device, dtype, op):
        func, *_ = self._get_funcs(op)
        sample = next(
            iter(
                op.sample_inputs(
                    dtype=dtype,
                    device=device,
                    requires_grad=True,
                    num_input_tensors=[2],
                    same_size=True,
                )
            )
        )
        self.assertTrue(all(t.requires_grad for t in sample.input))
        (out1, out2) = func(
            [sample.input, *sample.args],
            is_cuda=False,
            expect_fastpath=False,
            **sample.kwargs,
        )
        out1.backward(torch.ones_like(out1))
        self.assertIsNotNone(sample.input[0].grad)
        self.assertIsNone(sample.input[1].grad)

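    # For ops whose backward needs `result`, the output keeps its grad_fn (and
    # whatever the grad_fn owns) alive; once the output is dropped, everything
    # should be freed. We observe this via a weakref to an object stashed in
    # `grad_fn.metadata`.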
    @ops(
        filter(
            lambda op: op.backward_requires_result,
            foreach_unary_op_db
            + foreach_binary_op_db
            + foreach_pointwise_op_db
            + foreach_other_op_db,
        ),
        dtypes=(torch.float32,),
    )
    def test_lifetime_of_grad_fn_when_result_is_saved(self, device, dtype, op):
        def get_ref(func, sample):
            class Foo:
                pass

            out = func(
                (sample.input, *sample.args),
                is_cuda=False,
                expect_fastpath=False,
                **sample.kwargs,
            )
            foo = Foo()
            meta_dict = out[0].grad_fn.metadata
            meta_dict[0] = foo
            ref = weakref.ref(foo)
            return out, ref

        def _test(func, sample):
            out, ref = get_ref(func, sample)
            self.assertIsNotNone(ref())
            del out
            self.assertIsNone(ref())

        func = self._get_funcs(op)[0]
        for sample in op.sample_inputs(
            device, dtype, requires_grad=True, num_input_tensors=[1]
        ):
            for key in ("is_fastpath", "disable_fastpath"):
                if key in sample.kwargs:
                    del sample.kwargs[key]
            # note: `_foreach_pow.Scalar` and `_foreach_pow.ScalarList` don't depend on `result`
            # see: https://github.com/pytorch/pytorch/blob/5403c777/tools/autograd/derivatives.yaml#L3048-L3049
            if op.name == "_foreach_pow":
                if (
                    isinstance(sample.args[0], list)
                    and isinstance(sample.args[0][0], Number)
                ) or (
                    isinstance(sample.args[0], Number)
                    and not isinstance(sample.args[0], float)
                ):
                    continue
                if isinstance(sample.args[0], float):
                    new_args = (sample.input,)
                    sample.input = sample.args[0]
                    sample.args = new_args
            _test(func, sample)

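    # `_group_tensors_by_device_and_dtype` buckets parallel tensorlists by the
    # (device, dtype) of the first list, keyed roughly like (an illustrative
    # sketch)
    #   {(cuda:0, float16): ([l1, l2, l3], indices), ...}
    # preserving `None` entries and the indices back into the original lists.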
    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_tensors_grouping(self):
        num_tensors_per_list = 10
        num_devices = torch.cuda.device_count()
        dtypes = (torch.float16, torch.float32, torch.float64)
        list1 = [
            torch.tensor(
                i,
                device=torch.device("cuda", random.randint(0, num_devices - 1)),
                dtype=dtypes[random.randint(0, 2)],
            )
            for i in range(num_tensors_per_list)
        ]
        list2 = [None for _ in list1]
        list3 = [torch.rand_like(t) for t in list1]
        nested_tensorlists = [list1, list2, list3]
        grouped_tensors = torch.utils._foreach_utils._group_tensors_by_device_and_dtype(
            nested_tensorlists, with_indices=True
        )
        num_tensors_seen = 0
        for (device, dtype), ([l1, l2, l3], indices) in grouped_tensors.items():
            for t in itertools.chain(l1, l3):
                self.assertEqual(t.device, device)
                self.assertEqual(t.dtype, dtype)
                num_tensors_seen += 1
            self.assertEqual(len(l1), len(l2))
            self.assertTrue(all(p is None for p in l2))
            for i, index in enumerate(indices):
                self.assertEqual(l1[i], list1[index])
                self.assertEqual(l2[i], list2[index])
                self.assertEqual(l3[i], list3[index])
        self.assertEqual(num_tensors_seen, 2 * num_tensors_per_list)

    @onlyCUDA
    def test_0dim_tensor_overload_cpu_ok(self):
        tensors = [torch.ones((), device="cuda", dtype=torch.float32) for _ in range(2)]
        scalar_cpu_tensor = torch.tensor(4.0, device="cpu")

        # For mul and div, the scalar is allowed to be on CPU too
        actual = torch._foreach_mul(tensors, scalar_cpu_tensor)
        self.assertEqual(actual, [t.mul(scalar_cpu_tensor) for t in tensors])
        actual = torch._foreach_div(tensors, scalar_cpu_tensor)
        self.assertEqual(actual, [t.div(scalar_cpu_tensor) for t in tensors])

    @onlyCUDA
    def test_div_reciprocal(self):
        expect_m, expect_e = torch.frexp(
            torch.div(torch.tensor(0.1, device="cuda"), 10.0)
        )
        actual_m, actual_e = torch.frexp(
            torch._foreach_div([torch.tensor(0.1, device="cuda")], [10.0])[0]
        )
        self.assertEqual(expect_m, actual_m)
        self.assertEqual(expect_e, actual_e)

    @onlyCUDA
    def test_0dim_tensor_overload_exception(self):
        # check exceptions of fast path
        tensors = [
            make_tensor((2, 2), dtype=torch.float, device="cuda") for _ in range(2)
        ]
        with self.assertRaisesRegex(RuntimeError, "scalar tensor expected to be on"):
            torch._foreach_add(tensors, torch.tensor(1.0, device="cpu"), alpha=1.0)

        tensors = [
            make_tensor((2, 2), dtype=torch.float, device=d) for d in ("cpu", "cuda")
        ]
        with self.assertRaisesRegex(
            RuntimeError, "scalar tensor expected to be 0 dim but"
        ):
            torch._foreach_mul(tensors, torch.tensor([1.0, 1.0], device="cuda"))
        with self.assertRaisesRegex(
            RuntimeError, "scalar tensor expected to be 0 dim but"
        ):
            torch._foreach_add(tensors, torch.tensor([1.0, 1.0], device="cuda"))

    @onlyCUDA
    @ops(filter(lambda op: op.name == "_foreach_copy", foreach_binary_op_db))
    def test_foreach_copy_with_multi_device_inputs(self, device, dtype, op):
        foreach_copy_ = op.inplace_variant
        copy_ = op.ref_inplace
        for non_blocking in (False, True):
            for sample in op.sample_inputs(
                device, dtype, noncontiguous=False, allow_higher_dtype_scalars=True
            ):
                with torch.no_grad():
                    ref_input = [t.clone().detach() for t in sample.input]
                foreach_copy_(sample.input, sample.args[0], non_blocking)
                for t, s in zip(ref_input, sample.args[0]):
                    copy_(t, s, non_blocking)
                self.assertEqual(sample.input, ref_input)
                if torch.cuda.device_count() > 1:
                    device = torch.device("cuda", 1)
                    rhs_tensors = [t.to(device) for t in sample.args[0]]
                    foreach_copy_(sample.input, rhs_tensors, non_blocking)
                    for t, s in zip(ref_input, rhs_tensors):
                        copy_(t, s, non_blocking)
                    self.assertEqual(ref_input, sample.input)

    @onlyCUDA
    @ops(filter(lambda op: op.name == "_foreach_copy", foreach_binary_op_db))
    def test_foreach_copy_with_multi_dtypes(self, device, dtype, op):
        # check (a) multi_tensor_apply is called and (b) numerical parity with for-loop and Tensor.copy_
        foreach_copy_ = ForeachFuncWrapper(op.inplace_variant)
        for sample in op.sample_inputs(
            device, dtype, noncontiguous=False, allow_higher_dtype_scalars=True
        ):
            for src_dtype in floating_types_and(torch.half, torch.bfloat16):
                if src_dtype == dtype:
                    continue
                self_tensors = [t.clone() for t in sample.input]
                src_tensors = [t.to(src_dtype) for t in self_tensors]
                out = foreach_copy_(
                    (self_tensors, src_tensors), is_cuda=True, expect_fastpath=True
                )
                ref_out = [
                    torch.empty_like(t).copy_(s)
                    for t, s in zip(self_tensors, src_tensors)
                ]
                for t, ref_t in zip(out, ref_out):
                    self.assertTrue(torch.equal(t, ref_t))

    # Test reverse-mode & forward-mode AD if supported.
    @onlyCUDA
    @ops(
        foreach_unary_op_db
        + foreach_binary_op_db
        + foreach_pointwise_op_db
        + foreach_reduce_op_db
        + foreach_other_op_db,
        dtypes=OpDTypes.supported,
        allowed_dtypes=(torch.float64, torch.complex128),
    )
    @parametrize(
        "inplace", (False, True), name_fn=lambda x: "inplace" if x else "outplace"
    )
    def test_autodiff(self, device, dtype, op, inplace):
        if (not inplace) and not op.supports_out:
            self.skipTest("out-of-place not implemented")
        if inplace and op.has_no_in_place:
            self.skipTest("in-place not implemented")
        if not (
            op.supports_autograd
            or op.supports_inplace_autograd
            or op.supports_forward_ad
        ):
            self.skipTest("neither reverse mode nor forward mode supported")

        # note(crcrpar): without this value-range restriction, some unary functions
        # fail for float64 out-of-place, unlike their in-place and/or complex counterparts.
        if (
            (not inplace)
            and dtype == torch.float64
            and op.name
            in (
                "_foreach_acos",
                "_foreach_asin",
                "_foreach_log10",
                "_foreach_log1p",
                "_foreach_log2",
                "_foreach_log",
                "_foreach_pow",
                "_foreach_sqrt",
            )
        ):
            value_range = {"low": 0.5, "high": 1.0}
        else:
            value_range = {}
        for sample in op.sample_inputs(
            device,
            dtype,
            requires_grad=True,
            num_input_tensors=[5],
            allow_higher_dtype_scalars=True,
            **value_range,
        ):
            # Skip `_foreach_pow.ScalarAndTensor(Scalar, Tensor[])`
            if op.name == "_foreach_pow" and isinstance(sample.input, Number):
                continue

            func = None
            if inplace:
                # Call `clone` to avoid in-place modifications, as in
                # `torch.testing._internal.common_utils.TestGradients._get_safe_inplace`
                def inplace_func(*tensorlist):
                    kwargs = (
                        {"alpha": sample.kwargs["alpha"]}
                        if "alpha" in sample.kwargs
                        else {}
                    )
                    op.inplace_variant(
                        tuple(t.clone() for t in tensorlist), *sample.args, **kwargs
                    )
                    return tensorlist

                func = inplace_func
            else:

                def outplace_func(*tensorlist):
                    kwargs = (
                        {"alpha": sample.kwargs["alpha"]}
                        if "alpha" in sample.kwargs
                        else {}
                    )
                    return op.method_variant(tensorlist, *sample.args, **kwargs)

                func = outplace_func

            working_sample, err_msg_pattern = check_autodiff_sample(
                op, sample, dtype, inplace
            )

            def call_gradcheck():
                gradcheck(
                    func,
                    sample.input,
                    raise_exception=True,
                    check_forward_ad=op.supports_forward_ad,
                    check_batched_forward_grad=False,
                    check_backward_ad=op.supports_autograd,
                    check_batched_grad=False,
                )

            if not working_sample:
                if not err_msg_pattern:
                    # lhs of float64 and rhs of complex.
                    continue
                with self.assertRaisesRegex(RuntimeError, re.escape(err_msg_pattern)):
                    call_gradcheck()
                continue
            call_gradcheck()

            # Test per-tensor `grad_fn` behavior.
            if inplace and op.supports_inplace_autograd:
                # per-tensor `grad_fn` check.
                hook_buffer = []

                def get_grad_fn_hook(i):
                    def hook(grad_inputs, grad_outputs) -> None:
                        hook_buffer.append(i)

                    return hook

                _inputs = [t.clone().detach().requires_grad_() for t in sample.input]
                inputs = [t.clone() for t in _inputs]
                kwargs = (
                    {"alpha": sample.kwargs["alpha"]}
                    if "alpha" in sample.kwargs
                    else {}
                )
                op.inplace_variant(inputs, *sample.args, **kwargs)

                self.assertEqual(len({t.grad_fn for t in inputs}), len(inputs))

                for i, t in enumerate(inputs):
                    t.grad_fn.register_hook(get_grad_fn_hook(i))

                torch.autograd.grad(
                    inputs[0],
                    inputs=(_inputs[0],),
                    grad_outputs=(torch.rand_like(inputs[0]),),
                    retain_graph=True,
                )
                self.assertEqual(hook_buffer, [0])
                hook_buffer.clear()

                # tensors have different shapes.
                sum_of_cloned_tensors = torch.cat([t.view(-1) for t in inputs]).sum()
                grad_output = torch.rand_like(sum_of_cloned_tensors)
                torch.autograd.grad(
                    sum_of_cloned_tensors,
                    inputs=tuple(_inputs),
                    grad_outputs=(grad_output,),
                    retain_graph=False,
                )
                self.assertEqual(hook_buffer, list(reversed(range(len(inputs)))))


# TODO(crcrpar): Hide this inside torch/testing/_internal.
# That would end up adding another layer to `foreach_inputs_sample_func.__call__`
# so that this function could be used as, e.g., the first argument of `filter`.
# Even after moving this function to testing, I personally think it'd be better to check the error message.
def check_autodiff_sample(op, sample, dtype, is_inplace):
    if op.name == "_foreach_abs" and is_inplace and dtype == torch.complex128:
        return False, "In-place abs is not supported for complex tensors."
    if op.name == "_foreach_sub" and (
        (
            isinstance(sample.args[-1], list)
            and any(isinstance(a, bool) for a in sample.args[-1])
        )
        or isinstance(sample.args[-1], bool)
    ):
        return False, _BOOL_SUB_ERR_MSG
    if op.name == "_foreach_norm" and (not is_inplace):
        return (
            False,
            "Trying to set a forward gradient that has a different size than that of the original Tensor, "
            "this is not supported. Tensor is of size [] while the given forward gradient is of size [1, 1].",
        )
    rhs_arg_has_complex_number = sample.args and (
        (
            isinstance(sample.args[-1], list)
            and any(isinstance(a, complex) for a in sample.args[-1])
        )
        or (isinstance(sample.args[-1], complex))
    )
    if rhs_arg_has_complex_number and dtype == torch.float64:
        if op.name in (
            "_foreach_clamp_max",
            "_foreach_clamp_min",
            "_foreach_maximum",
            "_foreach_minimum",
        ):
            return False, "clamp is not supported for complex types"
        if op.name == "_foreach_lerp" and is_inplace:
            return False, "value cannot be converted to type double without overflow"
        if not is_inplace:
            return False, ""
        else:
            if op.name == "_foreach_pow":
                return False, "Found dtype Double but expected ComplexDouble"
            if op.name in (
                "_foreach_add",
                "_foreach_sub",
                "_foreach_mul",
                "_foreach_div",
            ):
                return (
                    False,
                    "result type ComplexDouble can't be cast to the desired output type Double",
                )
    return True, ""


instantiate_device_type_tests(TestForeach, globals())


if __name__ == "__main__":
    run_tests()