# Owner(s): ["module: mta"]

import itertools
import os
import random
import re
import unittest
import weakref
from contextlib import nullcontext
from numbers import Number

import torch
from torch.testing import make_tensor
from torch.testing._comparison import default_tolerances
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
    onlyCUDA,
    OpDTypes,
    ops,
)
from torch.testing._internal.common_dtype import (
    all_types_and_complex_and,
    floating_types,
    floating_types_and,
    integral_types_and,
)
from torch.testing._internal.common_methods_invocations import (
    foreach_binary_op_db,
    foreach_other_op_db,
    foreach_pointwise_op_db,
    foreach_reduce_op_db,
    foreach_unary_op_db,
)
from torch.testing._internal.common_utils import (
    gradcheck,
    parametrize,
    run_tests,
    skipIfRocmVersionLessThan,
    skipIfTorchDynamo,
    TEST_WITH_ROCM,
    TestCase,
)


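# `torch.sub` (and `torch._foreach_sub`) reject bool tensors; the RuntimeError
# message begins with this prefix, which several tests below match against.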
_BOOL_SUB_ERR_MSG = "Subtraction, the `-` operator"


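# Reference-side wrapper: applies a regular per-tensor function across zipped
# tensorlists so its outputs (and errors) can be compared against the
# corresponding `torch._foreach_*` results.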
class RegularFuncWrapper:
    def __init__(self, func):
        self.func = func

    def __call__(self, inputs, scalars=None, **kwargs):
        if scalars is not None:
            assert len(inputs) == 3
            # Each scalar must be distributed to the regular func individually,
            # and it needs special handling because it is a keyword-only
            # argument to the regular func (strangely, it is not keyword-only
            # for the foreach func).
            return [
                self.func(*i, value=scalars[idx], **kwargs)
                for idx, i in enumerate(zip(*inputs))
            ]
        if len(inputs) == 2 and isinstance(inputs[1], (Number, torch.Tensor)):
            # binary op with tensorlist and scalar.
            inputs[1] = [inputs[1] for _ in range(len(inputs[0]))]
        return [self.func(*i, **kwargs) for i in zip(*inputs)]


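# Foreach-side wrapper: calls a `torch._foreach_*` variant and, when CUDA
# profiling is available, checks whether `multi_tensor_apply_kernel` was
# launched, i.e. whether the fused fastpath was actually taken (or skipped).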
class ForeachFuncWrapper:
    def __init__(self, func):
        self.func = func
        # Some foreach functions don't have in-place implementations.
        self.is_inplace = False if func is None else func.__name__.endswith("_")

    def __call__(self, inputs, is_cuda, expect_fastpath, **kwargs):
        actual = None
        zero_size = kwargs.pop("zero_size", False)
        if (
            is_cuda
            and torch.autograd.kineto_available()
            and torch.profiler.ProfilerActivity.CUDA
            in torch.profiler.supported_activities()
        ):
            with torch.profiler.profile() as p:
                actual = self.func(*inputs, **kwargs)
            keys = tuple([e.key for e in p.key_averages()])
            mta_called = any("multi_tensor_apply_kernel" in k for k in keys)
            assert (
                mta_called == (expect_fastpath and (not zero_size))
            ), f"{mta_called=}, {expect_fastpath=}, {zero_size=}, {self.func.__name__=}, {keys=}"
        else:
            actual = self.func(*inputs, **kwargs)
        if self.is_inplace:
            assert id(inputs[0]) == id(actual)
        return actual


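# Context manager asserting that in-place foreach ops bump the autograd
# version counter (`Tensor._version`) of every tensor they mutate.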
class InplaceForeachVersionBumpCheck:
    def __init__(
        self,
        testcase: TestCase,
        tensorlist: "List[torch.Tensor]",  # noqa: F821
    ) -> None:
        self._testcase = testcase
        self._tensorlist = tensorlist
        self._orig_version_counts = [t._version for t in tensorlist]

    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_value, traceback):
        # note(crcrpar): some methods e.g. `_binary_test` could call the given inplace function multiple times
        self._testcase.assertGreaterEqual(
            [t._version for t in self._tensorlist], self._orig_version_counts
        )


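# Returns a closure that replaces every non-scalar tensor in a sample with a
# fresh `requires_grad=True` tensor, noncontiguous when exercising the slowpath.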
def get_transform_func(num_tensors, dtype, device, is_fastpath):
    def transform(t):
        if not torch.is_tensor(t):
            return t
        if torch.is_tensor(t) and t.ndim == 0:
            return t
        return make_tensor(
            (num_tensors, num_tensors),
            dtype=dtype,
            device=device,
            requires_grad=True,
            noncontiguous=not is_fastpath,
        )

    return transform


# note(crcrpar): `zero_size` is `False` unless (dtype, device) == (torch.float32, "cuda")
# as the pair would go through `multi_tensor_apply_kernel` if inputs are not zero size.
@unittest.mock.patch.dict(os.environ, {"KINETO_LOG_LEVEL": "5"})
class TestForeach(TestCase):
    @property
    def is_cuda(self):
        return self.device_type == "cuda"

    def _get_funcs(self, op):
        return (
            ForeachFuncWrapper(op.method_variant),
            RegularFuncWrapper(op.ref),
            ForeachFuncWrapper(op.inplace_variant),
            RegularFuncWrapper(op.ref_inplace),
        )

    # note(crcrpar): Make sure 0-size tensors are appropriately ignored by `multi_tensor_apply`
    # which is originally reported in https://github.com/pytorch/pytorch/issues/94865.
    # rel:
    #   - https://github.com/pytorch/pytorch/pull/94655
    #   - https://github.com/pytorch/pytorch/issues/100701
    #   - https://github.com/pytorch/pytorch/pull/100811
    @onlyCUDA
    @ops(
        foreach_unary_op_db
        + foreach_binary_op_db
        + foreach_pointwise_op_db
        + foreach_reduce_op_db
        + foreach_other_op_db,
        dtypes=(torch.float32,),
    )
    def test_all_zero_size_tensors_do_not_launch_kernel(self, device, dtype, op):
        wrapped_op, _, inplace_op, _ = self._get_funcs(op)

        for sample in op.sample_zero_size_inputs(device, dtype):
            if op.method_variant is not None:
                wrapped_op(
                    (sample.input, *sample.args),
                    is_cuda=self.is_cuda,
                    expect_fastpath=True,
                    zero_size=True,
                )

            if op.inplace_variant is not None:
                with InplaceForeachVersionBumpCheck(self, sample.input):
                    inplace_op(
                        (sample.input, *sample.args),
                        is_cuda=self.is_cuda,
                        expect_fastpath=True,
                        zero_size=True,
                    )

    @skipIfRocmVersionLessThan((6, 0))
    @ops(
        foreach_unary_op_db
        + foreach_binary_op_db
        + foreach_pointwise_op_db
        + foreach_reduce_op_db
        + foreach_other_op_db,
    )
    @parametrize(
        "noncontiguous,inplace",
        [(False, False), (False, True), (True, False), (True, True)],
        name_fn=lambda x, y: "{}_{}".format(
            "fastpath" if not x else "slowpath", "inplace" if y else "outplace"
        ),
    )
    def test_parity(self, device, dtype, op, noncontiguous, inplace):
        if inplace:
            _, _, func, ref = self._get_funcs(op)
        else:
            func, ref, _, _ = self._get_funcs(op)
        for sample in op.sample_inputs(
            device, dtype, noncontiguous=noncontiguous, allow_higher_dtype_scalars=True
        ):
            ref_kwargs = sample.kwargs
            # div promotes ints to floats, so we cannot go on the fastpath there
            div_slowpath = (
                dtype in integral_types_and(torch.bool) and op.name == "_foreach_div"
            )
            expect_fastpath = not (
                noncontiguous or sample.disable_fastpath or div_slowpath
            )
            ref_input, ctxmgr = sample.input, nullcontext()
            if inplace:
                with torch.no_grad():
                    ref_input = [t.clone().detach() for t in sample.input]
                ctxmgr = InplaceForeachVersionBumpCheck(self, sample.input)
            try:
                with ctxmgr:
                    actual = func(
                        [sample.input, *sample.args],
                        self.is_cuda,
                        expect_fastpath,
                        **sample.kwargs,
                    )
            except Exception as e:
                with self.assertRaises(type(e)):
                    ref([ref_input, *sample.ref_args], **ref_kwargs)
            else:
                expected = ref([ref_input, *sample.ref_args], **ref_kwargs)
                self.assertEqual(expected, actual)

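    # Shared helper for binary ops: runs the (possibly in-place) foreach op,
    # compares results and raised errors against the per-tensor reference, and
    # optionally repeats the check with an `alpha` keyword argument.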
    def _binary_test(
        self,
        dtype,
        op,
        ref,
        inputs,
        is_fastpath,
        is_inplace,
        *,
        alpha,
        scalar_self_arg: bool,
    ):
        ref_inputs = (
            [[t.clone().detach() for t in inputs[0]], inputs[1]]
            if is_inplace
            else inputs
        )
        try:
            with InplaceForeachVersionBumpCheck(
                self, inputs[0]
            ) if op.is_inplace else nullcontext():
                actual = op(inputs, self.is_cuda, is_fastpath)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])):
                if not scalar_self_arg:
                    ref(ref_inputs)
                else:
                    [ref.func(ref_inputs[0], t) for t in ref_inputs[1]]
        else:
            expected = (
                ref(ref_inputs)
                if not scalar_self_arg
                else [ref.func(ref_inputs[0], t) for t in ref_inputs[1]]
            )
            self.assertEqual(actual, expected)
        if alpha is not None and not scalar_self_arg:
            kwargs = {"alpha": alpha}
            ref_inputs = inputs
            try:
                op_kwargs = {}
                op_kwargs.update(kwargs)
                with InplaceForeachVersionBumpCheck(
                    self, inputs[0]
                ) if op.is_inplace else nullcontext():
                    actual = op(inputs, self.is_cuda, is_fastpath, **op_kwargs)
            except RuntimeError as e:
                with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])):
                    ref(ref_inputs, **kwargs)
            else:
                expected = ref(ref_inputs, **kwargs)
                if dtype in (torch.float16, torch.bfloat16) and TEST_WITH_ROCM:
                    self.assertEqual(
                        expected, actual, atol=1.0e-3, rtol=default_tolerances(dtype)[0]
                    )
                else:
                    self.assertEqual(expected, actual)

    @ops(filter(lambda op: op.supports_scalar_self_arg, foreach_binary_op_db))
    @parametrize("is_fastpath", (True, False))
    def test_binary_op_with_scalar_self_support(self, device, dtype, op, is_fastpath):
        def clone(arg):
            if isinstance(arg, (list, tuple)):
                return [clone(a) for a in arg]
            if torch.is_tensor(arg):
                return arg.clone().detach().requires_grad_()
            else:
                return arg

        scalar_self_arg_test_complete = False
        for i, sample in enumerate(
            op.sample_inputs(
                device,
                dtype,
                noncontiguous=not is_fastpath,
                allow_higher_dtype_scalars=True,
            )
        ):
            (rhs_arg,) = sample.args
            kwargs = sample.kwargs
            alpha = kwargs.pop("alpha", None)
            wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op)
            if isinstance(rhs_arg, Number) and not scalar_self_arg_test_complete:
                scalar_self_arg_test_complete = True
                self._binary_test(
                    dtype,
                    wrapped_op,
                    ref,
                    [rhs_arg, sample.input],
                    is_fastpath,
                    False,
                    alpha=alpha,
                    scalar_self_arg=True,
                )
                if op.supports_autograd and dtype == torch.float32:
                    transformed_sample = sample.transform(
                        get_transform_func(
                            len(sample.input), dtype, device, is_fastpath
                        )
                    )
                    tensors = transformed_sample.input
                    (rhs_arg,) = transformed_sample.args
                    ref_tensors, ref_rhs_arg = clone(tensors), clone(rhs_arg)
                    sum(
                        wrapped_op(
                            [rhs_arg, tensors], is_cuda=False, expect_fastpath=False
                        )
                    ).mean().backward()
                    sum(ref.func(ref_rhs_arg, t) for t in ref_tensors).mean().backward()
                    self.assertEqual(
                        [t.grad for t in tensors], [t.grad for t in ref_tensors]
                    )

    @ops(foreach_pointwise_op_db)
    @parametrize("is_fastpath", (True, False))
    def test_pointwise_op_with_tensor_of_scalarlist_overload(
        self, device, dtype, op, is_fastpath
    ):
        for sample in op.sample_inputs(
            device,
            dtype,
            noncontiguous=not is_fastpath,
            allow_higher_dtype_scalars=True,
        ):
            assert isinstance(sample.args, tuple)
            assert len(sample.args) == 2
            inputs = [sample.input, *sample.args]
            kwargs = sample.kwargs.copy()
            disable_fastpath = sample.disable_fastpath and is_fastpath
            wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op)
            scalars = kwargs.pop("scalars", None)

            if is_fastpath and scalars:
                sample = sample.transform(
                    lambda t: t.clone().detach() if torch.is_tensor(t) else t
                )
                inputs = [sample.input, *sample.args]
                tensor_values = torch.tensor(scalars)
                # 1D Tensor of scalars
                for is_inplace, op_, ref_ in (
                    (False, wrapped_op, ref),
                    (True, inplace_op, inplace_ref),
                ):
                    self._pointwise_test(
                        op_,
                        ref_,
                        inputs,
                        is_fastpath and not disable_fastpath,
                        is_inplace,
                        scalars=tensor_values,
                        **kwargs,
                    )
                    self._pointwise_test(
                        op_,
                        ref_,
                        inputs,
                        is_fastpath and not disable_fastpath,
                        is_inplace,
                        scalars=tensor_values[0],
                        custom_values_err="Expected packed scalar Tensor to be of dimension 1. Got 0 instead.",
                        **kwargs,
                    )
                    if self.is_cuda:
                        self._pointwise_test(
                            op_,
                            ref_,
                            inputs,
                            is_fastpath and not disable_fastpath,
                            is_inplace,
                            scalars=tensor_values.cuda(),
                            custom_values_err="Expected scalars to be on CPU, got cuda:0 instead.",
                            **kwargs,
                        )
                    self._pointwise_test(
                        op_,
                        ref_,
                        inputs,
                        is_fastpath and not disable_fastpath,
                        is_inplace,
                        scalars=tensor_values[:2],
                        custom_values_err=f"Expected length of scalars to match input of length {len(scalars)} but got 2 instead.",
                        **kwargs,
                    )
                    self._pointwise_test(
                        op_,
                        ref_,
                        inputs,
                        is_fastpath and not disable_fastpath,
                        is_inplace,
                        scalars=torch.tensor([[0, 1], [2, 3]])[:, 1],
                        custom_values_err="Expected scalars to be contiguous.",
                        **kwargs,
                    )

            # Tests of implicit broadcasting
            N = len(sample.input)
            inputs = [
                [
                    make_tensor(
                        (N, N),
                        device=device,
                        dtype=dtype,
                        noncontiguous=not is_fastpath,
                    )
                    for _ in range(N)
                ],
                [
                    make_tensor(
                        (N - i, 1),
                        device=device,
                        dtype=dtype,
                        noncontiguous=not is_fastpath,
                    )
                    for i in range(N)
                ],
                [
                    make_tensor(
                        (1, N - i),
                        device=device,
                        dtype=dtype,
                        noncontiguous=not is_fastpath,
                    )
                    for i in range(N)
                ],
            ]
            self._pointwise_test(
                wrapped_op,
                ref,
                inputs,
                is_fastpath and disable_fastpath,
                is_inplace=False,
                scalars=scalars,
                **kwargs,
            )
            self._pointwise_test(
                inplace_op,
                inplace_ref,
                inputs,
                is_fastpath and disable_fastpath,
                is_inplace=True,
                scalars=scalars,
                **kwargs,
            )

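    # Shared helper for pointwise ops: compares the foreach op against the
    # per-tensor reference, including the Tensor-of-scalars overloads and
    # their expected error messages.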
    def _pointwise_test(
        self,
        op,
        ref,
        inputs,
        is_fastpath,
        is_inplace,
        *,
        scalars=None,
        custom_values_err=None,
        **kwargs,
    ):
        ref_inputs = (
            [[t.clone().detach() for t in inputs[0]], inputs[1], inputs[2]]
            if is_inplace
            else inputs
        )
        try:
            with (
                InplaceForeachVersionBumpCheck(self, inputs[0])
                if is_inplace
                else nullcontext()
            ):
                actual = op(inputs, self.is_cuda, is_fastpath, **kwargs)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])):
                ref(ref_inputs, **kwargs)
        else:
            expected = ref(ref_inputs, **kwargs)
            self.assertEqual(expected, actual)
        if scalars is not None:
            kwargs = kwargs.copy()
            kwargs["scalars"] = scalars
            try:
                actual = op(inputs, self.is_cuda, is_fastpath, **kwargs)
            except RuntimeError as e:
                # Match with error messages from regular non-foreach reference if no
                # custom error message was provided.
                if custom_values_err is None:
                    with self.assertRaisesRegex(
                        type(e), re.escape(str(e).splitlines()[0])
                    ):
                        ref(ref_inputs, **kwargs)
                else:
                    self.assertEqual(re.escape(str(e)), re.escape(custom_values_err))
            else:
                expected = ref(ref_inputs, **kwargs)
                self.assertEqual(expected, actual)

    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
    def test_add_scalar_with_empty_list_and_empty_tensor(self, device, dtype):
        # TODO: enable empty list case
        for tensors in [
            [torch.randn([0], device=device, dtype=dtype)],
            [torch.empty_strided((0, 1), (0, 0), dtype=dtype, device=device)],
        ]:
            res = torch._foreach_add(tensors, 1)
            self.assertEqual(res, tensors)

            torch._foreach_add_(tensors, 1)
            self.assertEqual(res, tensors)

            # Regression test for https://github.com/pytorch/pytorch/issues/113156
            torch._foreach_mul_(tensors, 1)

    @onlyCUDA
    @dtypes(torch.float32)
    def test_foreach_check_stride_ignore_dims_of_one(self, device, dtype):
        # Contiguous stride for shape (2, 1, 3, 3) is (9, 9, 3, 1); the strided
        # tensor below differs only in the stride of the size-1 dimension,
        # which the fastpath's stride check should ignore.
        tensor = torch.ones((2, 1, 3, 3), device=device, dtype=dtype)
        strided_tensor = torch.ones(
            (2, 1, 3, 3), device=device, dtype=dtype
        ).as_strided((2, 1, 3, 3), (9, 1, 3, 1))
        left_inputs = [tensor, strided_tensor]
        right_inputs = [strided_tensor, tensor]
        compare_result = tensor + strided_tensor
        foreach_add_check_ = ForeachFuncWrapper(torch._foreach_add)
        out = foreach_add_check_(
            (left_inputs, right_inputs), is_cuda=True, expect_fastpath=True
        )
        for res in out:
            self.assertEqual(res, compare_result)

    @ops(
        filter(lambda op: op.supports_out, foreach_binary_op_db),
        dtypes=OpDTypes.supported,
    )
    def test_binary_op_scalar_with_overlapping_tensors(self, device, dtype, op):
        foreach_op, ref = op.method_variant, op.ref
        tensors = [torch.ones(1, 1, device=device, dtype=dtype).expand(2, 1, 3)]

        if ref == torch.sub and dtype == torch.bool:
            with self.assertRaisesRegex(RuntimeError, re.escape(_BOOL_SUB_ERR_MSG)):
                [ref(t, 1) for t in tensors]
            with self.assertRaisesRegex(RuntimeError, re.escape(_BOOL_SUB_ERR_MSG)):
                foreach_op(tensors, 1)
            return

        expected = [ref(t, 1) for t in tensors]
        res = foreach_op(tensors, 1)
        self.assertEqual(res, expected)

    @ops(
        filter(lambda op: op.supports_out, foreach_binary_op_db),
        allowed_dtypes=[torch.float],
    )
    def test_binary_op_scalar_with_different_tensor_dtypes(self, device, dtype, op):
        foreach_op = op.method_variant
        tensors = [
            torch.tensor([1.1], dtype=torch.float, device=device),
            torch.tensor([1], dtype=torch.long, device=device),
        ]
        runtime_error = None
        try:
            foreach_op(tensors, 1)
        except RuntimeError as e:
            runtime_error = e
        self.assertIsNone(runtime_error)

    @skipIfTorchDynamo("Different error msgs, TODO")
    @ops(
        filter(lambda op: op.supports_out, foreach_binary_op_db),
        dtypes=OpDTypes.supported,
    )
    def test_binary_op_list_error_cases(self, device, dtype, op):
        foreach_op, foreach_op_, ref, ref_ = (
            op.method_variant,
            op.inplace_variant,
            op.ref,
            op.ref_inplace,
        )
        tensors1 = []
        tensors2 = []
        ops_to_test = [foreach_op, foreach_op_]

        # Empty lists
        for fop in ops_to_test:
            with self.assertRaisesRegex(
                RuntimeError, "Tensor list must have at least one tensor."
            ):
                fop(tensors1, tensors2)

        # One empty list
        tensors1.append(torch.tensor([1], device=device, dtype=dtype))
        for fop in ops_to_test:
            with self.assertRaisesRegex(
                RuntimeError,
                "Tensor list must have same number of elements as scalar list.",
            ):
                fop(tensors1, tensors2)

        # Lists have different numbers of tensors
        tensors2.append(torch.tensor([1], device=device))
        tensors2.append(torch.tensor([1], device=device))
        for fop in ops_to_test:
            with self.assertRaisesRegex(
                RuntimeError,
                "Tensor lists must have the same number of tensors, got 1 and 2",
            ):
                fop(tensors1, tensors2)
            with self.assertRaisesRegex(
                RuntimeError,
                "Tensor lists must have the same number of tensors, got 2 and 1",
            ):
                fop(tensors2, tensors1)

        # Corresponding tensors with sizes that aren't broadcast-compatible.
        # With mismatched sizes foreach takes the slow path, so the error
        # messages are expected to match those of the regular torch function.
652*da0073e9SAndroid Build Coastguard Worker        tensors1 = [torch.zeros(10, 10, device=device, dtype=dtype) for _ in range(10)]
653*da0073e9SAndroid Build Coastguard Worker        tensors2 = [torch.ones(11, 11, device=device, dtype=dtype) for _ in range(10)]
654*da0073e9SAndroid Build Coastguard Worker
655*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.bool and foreach_op == torch._foreach_sub:
656*da0073e9SAndroid Build Coastguard Worker            for fop in ops_to_test:
657*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, re.escape(_BOOL_SUB_ERR_MSG)):
658*da0073e9SAndroid Build Coastguard Worker                    fop(tensors1, tensors2)
659*da0073e9SAndroid Build Coastguard Worker            return
660*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
661*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
662*da0073e9SAndroid Build Coastguard Worker            r"The size of tensor a \(10\) must match the size of tensor b \(11\) at non-singleton dimension 1",
663*da0073e9SAndroid Build Coastguard Worker        ):
664*da0073e9SAndroid Build Coastguard Worker            foreach_op(tensors1, tensors2)
665*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
666*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
667*da0073e9SAndroid Build Coastguard Worker            r"The size of tensor a \(10\) must match the size of tensor b \(11\) at non-singleton dimension 1",
668*da0073e9SAndroid Build Coastguard Worker        ):
669*da0073e9SAndroid Build Coastguard Worker            foreach_op_(tensors1, tensors2)
670*da0073e9SAndroid Build Coastguard Worker
671*da0073e9SAndroid Build Coastguard Worker        # different devices
672*da0073e9SAndroid Build Coastguard Worker        if self.device_type == "cuda" and torch.cuda.device_count() > 1:
673*da0073e9SAndroid Build Coastguard Worker            tensor1 = torch.zeros(10, 10, device="cuda:0", dtype=dtype)
674*da0073e9SAndroid Build Coastguard Worker            tensor2 = torch.ones(10, 10, device="cuda:1", dtype=dtype)
675*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
676*da0073e9SAndroid Build Coastguard Worker                RuntimeError, "Expected all tensors to be on the same device"
677*da0073e9SAndroid Build Coastguard Worker            ):
678*da0073e9SAndroid Build Coastguard Worker                foreach_op([tensor1], [tensor2])
679*da0073e9SAndroid Build Coastguard Worker            if (
680*da0073e9SAndroid Build Coastguard Worker                dtype in integral_types_and(torch.bool)
681*da0073e9SAndroid Build Coastguard Worker                and foreach_op == torch._foreach_div
682*da0073e9SAndroid Build Coastguard Worker            ):
                with self.assertRaisesRegex(RuntimeError, "result type"):
                    foreach_op_([tensor1], [tensor2])
            else:
                with self.assertRaisesRegex(
                    RuntimeError, "Expected all tensors to be on the same device"
                ):
                    foreach_op_([tensor1], [tensor2])

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not found")
    @ops(
        filter(lambda op: op.supports_out, foreach_binary_op_db),
        dtypes=OpDTypes.supported,
    )
    def test_binary_op_list_slow_path(self, device, dtype, op):
        foreach_op, native_op, foreach_op_, native_op_ = self._get_funcs(op)
        # 0-strides
        tensor1 = make_tensor((10, 10), dtype=dtype, device=device)
        tensor2 = make_tensor((1,), device=device, dtype=dtype).expand_as(tensor1)
        inputs = ([tensor1], [tensor2])
        self._binary_test(
            dtype,
            foreach_op,
            native_op,
            inputs,
            is_fastpath=False,
            is_inplace=False,
            alpha=None,
            scalar_self_arg=False,
        )
        self._binary_test(
            dtype,
            foreach_op_,
            native_op_,
            inputs,
            is_fastpath=False,
            is_inplace=True,
            alpha=None,
            scalar_self_arg=False,
        )

        # different strides
        tensor1 = torch.zeros(10, 10, device=device, dtype=dtype)
        tensor2 = torch.ones(10, 10, device=device, dtype=dtype)
        inputs = ([tensor1], [tensor2.t()])
        self._binary_test(
            dtype,
            foreach_op,
            native_op,
            inputs,
            is_fastpath=False,
            is_inplace=False,
            alpha=None,
            scalar_self_arg=False,
        )
        self._binary_test(
            dtype,
            foreach_op_,
            native_op_,
            inputs,
            is_fastpath=False,
            is_inplace=True,
            alpha=None,
            scalar_self_arg=False,
        )

        # non contiguous
        tensor1 = make_tensor(
            (5, 2, 1, 3), device=device, dtype=dtype, noncontiguous=True
        )
        tensor2 = make_tensor(
            (5, 2, 1, 3), device=device, dtype=dtype, noncontiguous=True
        )
        self.assertFalse(tensor1.is_contiguous())
        self.assertFalse(tensor2.is_contiguous())
        inputs = ([tensor1], [tensor2])
        self._binary_test(
            dtype,
            foreach_op,
            native_op,
            inputs,
            is_fastpath=False,
            is_inplace=False,
            alpha=None,
            scalar_self_arg=False,
        )
        self._binary_test(
            dtype,
            foreach_op_,
            native_op_,
            inputs,
            is_fastpath=False,
            is_inplace=True,
            alpha=None,
            scalar_self_arg=False,
        )

        # sliced tensor
        tensor1 = make_tensor((5, 2, 1, 3), device=device, dtype=dtype)
        tensor2 = make_tensor((5, 2, 1, 3 * 7), device=device, dtype=dtype)[
            :, :, :, ::7
        ]
        inputs = ([tensor1], [tensor2])
        self._binary_test(
            dtype,
            foreach_op,
            native_op,
            inputs,
            is_fastpath=False,
            is_inplace=False,
            alpha=None,
            scalar_self_arg=False,
        )
        self._binary_test(
            dtype,
            foreach_op_,
            native_op_,
            inputs,
            is_fastpath=False,
            is_inplace=True,
            alpha=None,
            scalar_self_arg=False,
        )

    @ops(
        filter(lambda op: op.supports_out, foreach_binary_op_db),
        dtypes=floating_types_and(torch.half, torch.bfloat16),
    )
    def test_binary_op_float_inf_nan(self, device, dtype, op):
        inputs = (
            [
                torch.tensor([float("inf")], device=device, dtype=dtype),
                torch.tensor([-float("inf")], device=device, dtype=dtype),
                torch.tensor([float("nan")], device=device, dtype=dtype),
                torch.tensor([float("nan")], device=device, dtype=dtype),
            ],
            [
                torch.tensor([-float("inf")], device=device, dtype=dtype),
                torch.tensor([float("inf")], device=device, dtype=dtype),
                torch.tensor([float("inf")], device=device, dtype=dtype),
                torch.tensor([float("nan")], device=device, dtype=dtype),
            ],
        )
        op, ref, inplace_op, inplace_ref = self._get_funcs(op)
        self._binary_test(
            dtype, op, ref, inputs, True, False, alpha=None, scalar_self_arg=False
        )
        self._binary_test(
            dtype,
            inplace_op,
            inplace_ref,
            inputs,
            True,
            True,
            alpha=None,
            scalar_self_arg=False,
        )

    # note: The below three tests (postfixed with `_tensors_on_different_devices`)
    # check whether foreach works with lists of tensors on different devices,
    # where tensors at the same index across lists share a device, e.g., ['cuda', 'cpu'].
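    # A minimal sketch of the layout these tests exercise (illustrative only;
    # assumes at least one CUDA device):
    #   tensors1 = [torch.ones(3, device="cuda:0"), torch.ones(3, device="cpu")]
    #   tensors2 = [torch.ones(3, device="cuda:0"), torch.ones(3, device="cpu")]
    #   torch._foreach_add(tensors1, tensors2)
    # Same-index pairs share a device, so each per-index computation is well
    # defined; the tests below compare the foreach result (or the error it
    # raises) against the per-tensor reference ops.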
    @onlyCUDA
    @ops(foreach_unary_op_db)
    def test_unary_op_tensors_on_different_devices(self, device, dtype, op):
        method, ref, inplace_method, ref_inplace = self._get_funcs(op)
        # tensors: ['cuda', 'cpu']
        tensors = next(
            iter(
                op.sample_inputs(
                    device,
                    dtype,
                    num_input_tensors=[2],
                    allow_higher_dtype_scalars=True,
                )
            )
        ).input
        tensors[1] = tensors[1].to("cpu")
        if not op.supports_out:
            try:
                actual = method((tensors,), False, False, zero_size=False)
            except RuntimeError as e:
                with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])):
                    ref((tensors,))
            else:
                expected = ref((tensors,))
                self.assertEqual(expected, actual)

        try:
            inplace_method((tensors,), False, False, zero_size=False)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])):
                ref_inplace((tensors,))
        else:
            if not op.supports_out:
                self.assertEqual(expected, tensors)
            else:
                self.assertEqual([torch.zeros_like(t) for t in tensors], tensors)

    @onlyCUDA
    @ops(filter(lambda op: op.supports_out, foreach_binary_op_db))
    def test_binary_op_tensors_on_different_devices(self, device, dtype, op):
        _cuda_tensors = next(
            iter(
                op.sample_inputs(
                    device,
                    dtype,
                    num_input_tensors=[2],
                    same_size=True,
                    allow_higher_dtype_scalars=True,
                )
            )
        ).input
        _cpu_tensors = next(
            iter(
                op.sample_inputs(
                    "cpu",
                    dtype,
                    num_input_tensors=[2],
                    same_size=True,
                    allow_higher_dtype_scalars=True,
                )
            )
        ).input
        tensors1, tensors2 = list(zip(_cuda_tensors, _cpu_tensors))

        foreach_op, foreach_op_ = op.method_variant, op.inplace_variant
        native_op, native_op_ = op.ref, op.ref_inplace
        try:
            actual = foreach_op(tensors1, tensors2)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])):
                [native_op(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
        else:
            expected = [native_op(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
            self.assertEqual(expected, actual)
        try:
            foreach_op_(tensors1, tensors2)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])):
                [native_op_(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
        else:
            self.assertEqual(actual, tensors1)

    @onlyCUDA
    @ops(foreach_pointwise_op_db, allowed_dtypes=floating_types())
    def test_pointwise_op_tensors_on_different_devices(self, device, dtype, op):
        # tensors1: ['cuda', 'cpu']
        # tensors2: ['cuda', 'cpu']
        # tensors3: ['cuda', 'cpu']
        # the first sample's tensorlist is zero-size when dtype is float32, so skip it
        _cuda_tensors = list(
            op.sample_inputs(
                device,
                dtype,
                num_input_tensors=[3],
                same_size=True,
                allow_higher_dtype_scalars=True,
            )
        )[int(dtype == torch.float32)].input
        _cpu_tensors = next(
            iter(
                op.sample_inputs(
                    "cpu",
                    dtype,
                    num_input_tensors=[3],
                    same_size=True,
                    allow_higher_dtype_scalars=True,
                )
            )
        ).input
        tensors1, tensors2, tensors3 = list(zip(_cuda_tensors, _cpu_tensors))

        foreach_op, foreach_op_, native_op = (
            op.method_variant,
            op.inplace_variant,
            op.ref,
        )
        actual = foreach_op(tensors1, tensors2, tensors3)
        expected = [native_op(*_cuda_tensors), native_op(*_cpu_tensors)]
        self.assertEqual(expected, actual)

        # note(mkozuki): since dtypes are limited to FP32 and FP64 here, the in-place ops can run safely.
        foreach_op_(tensors1, tensors2, tensors3)
        self.assertEqual(expected, tensors1)

    # note: BFloat16 has the same number of exponent bits as FP32, so if the
    # squared L2 norm overflows in BF16, it also overflows in FP32.
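    # For reference (editorial): torch.finfo(torch.bfloat16).max ~= 3.39e38 and
    # torch.finfo(torch.float32).max ~= 3.40e38 (both formats have 8 exponent
    # bits), whereas torch.finfo(torch.float16).max is only 65504, so FP16
    # norms can overflow even when an FP32 accumulation would not.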
    @onlyCUDA
    @ops(
        [o for o in foreach_reduce_op_db if "norm" in o.name],
        allowed_dtypes=(torch.half, torch.bfloat16),
    )
    def test_foreach_l2_large_value_input(self, device, dtype, op):
        ord, N = 2, 10
        max_value = torch.finfo(dtype).max
        scaler = torch.tensor([max_value]).sqrt().to(device=device, dtype=dtype)
        inputs = (
            [
                t * scaler
                for t in next(
                    iter(
                        op.sample_inputs(
                            device,
                            dtype,
                            requires_grad=True,
                            num_input_tensors=[N],
                            low=1,
                        )
                    )
                ).input
            ][:-1],
        )
        # make sure that the minimum squared L2 norm per tensor exceeds the max value of `dtype`.
        self.assertTrue(scaler * scaler * N > max_value)
        fn, ref_fn, *_ = self._get_funcs(op)
        actual = fn(
            inputs, is_cuda=True, expect_fastpath=True, ord=ord, zero_size=False
        )
        expect = ref_fn(inputs, ord=ord)

        if dtype == torch.float16:
            # making sure the reference L2 norm values are in the range of FP16.
            self.assertFalse(any(torch.isinf(e) for e in expect))
        else:
            self.assertTrue(
                all(
                    inputs[0][i].numel() == 0 or torch.isinf(e)
                    for i, e in enumerate(expect)
                )
            )
        self.assertEqual(expect, actual, equal_nan=False)

    @onlyCUDA
    @ops(foreach_reduce_op_db, allowed_dtypes=floating_types())
    @parametrize("use_cuda_graph", (False, True))
    def test_big_num_tensors(self, device, dtype, op, use_cuda_graph):
        N = 600
        tensorlist = [
            make_tensor((2, 3), dtype=dtype, device=device, noncontiguous=False)
            for _ in range(N)
        ]
        fn, ref_fn, *_ = self._get_funcs(op)

        import math

        if op.name == "_foreach_norm":
            ords = (1, 2, math.inf)
        else:
            ords = (None,)

        for ord in ords:
            kwargs = {"ord": ord} if ord else {}
            if not use_cuda_graph:
                actual = fn(
                    inputs=[tensorlist],
                    is_cuda=True,
                    expect_fastpath=True,
                    zero_size=False,
                    **kwargs,
                )
            else:
                # When using CUDA graphs and the tensor metadata doesn't fit in
                # the static kernel argument space, multi_tensor_apply creates
                # the launch arguments once, uses cudaUserObject_t to tie their
                # lifetime to the graph, and reuses them throughout replays.
                # This test verifies multi_tensor_apply's behavior in that scenario.
                g = torch.cuda.CUDAGraph()
                with torch.cuda.graph(g):
                    actual = fn.func(tensorlist, **kwargs)
                g.replay()
            expect = ref_fn(inputs=[tensorlist], **kwargs)

            self.assertEqual(expect, actual, equal_nan=True)

    @onlyCUDA
    @ops(foreach_reduce_op_db)
    def test_foreach_reduce_large_input(self, device, dtype, op):
        # test inputs larger than kChunkSize = 65536
        N = 65536 * 2
        disable_fastpath = False
        kwargs = {}
        if op.name == "_foreach_norm":
            ord = 2
            disable_fastpath = not (
                ord in (1, 2)
                and dtype in floating_types_and(torch.half, torch.bfloat16)
            )
            kwargs["ord"] = ord

        inputs = ([make_tensor((N,), dtype=dtype, device=device, noncontiguous=False)],)
        wrapped_op, ref, _, _ = self._get_funcs(op)
        self.assertEqual(
            ref(inputs, **kwargs),
            wrapped_op(
                inputs, self.is_cuda, not disable_fastpath, zero_size=False, **kwargs
            ),
        )

    @onlyCUDA
    @ops(
        foreach_unary_op_db
        + foreach_binary_op_db
        + foreach_pointwise_op_db
        + foreach_other_op_db,
        dtypes=(torch.float,),
    )
    def test_inplace_foreach_leaf_check_and_grad_fn(self, device, dtype, op):
        inplace_op = op.inplace_variant
        if inplace_op is None:
            self.skipTest("no in-place op available")

        sample = next(
            iter(
                op.sample_inputs(
                    dtype=dtype, device=device, num_input_tensors=[2], same_size=True
                )
            )
        )
        sample.input[0].requires_grad_(True)
        with self.assertRaisesRegex(RuntimeError, "a leaf Variable that requires grad"):
            inplace_op(sample.input, *sample.args)
        sample.input[1].requires_grad_(True)
        with self.assertRaisesRegex(RuntimeError, "a leaf Variable that requires grad"):
            inplace_op(sample.input, *sample.args)

        _tensors = [
            t.clone().detach().requires_grad_(i == 0)
            for i, t in enumerate(sample.input)
        ]
        tensors = [t.clone() for t in _tensors]
        inplace_op(tensors, *sample.args)
        self.assertIsNotNone(tensors[0].grad_fn)
        self.assertIsNone(tensors[1].grad_fn)
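
        # For context (editorial sketch): the same leaf check applies to the
        # single-tensor in-place ops these wrap, e.g.
        #   x = torch.ones(2, requires_grad=True)
        #   x.add_(1)  # RuntimeError: a leaf Variable that requires grad ...
        # while a non-leaf clone may be modified in place and gains a grad_fn.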

    @onlyCUDA
    @ops(
        filter(
            lambda op: op.supports_out,
            foreach_unary_op_db
            + foreach_binary_op_db
            + foreach_pointwise_op_db
            + foreach_other_op_db,
        ),
        dtypes=(torch.float,),
    )
    def test_outplace_with_invalid_grads(self, device, dtype, op):
        func, *_ = self._get_funcs(op)
        sample = next(
            iter(
                op.sample_inputs(
                    dtype=dtype,
                    device=device,
                    requires_grad=True,
                    num_input_tensors=[2],
                    same_size=True,
                )
            )
        )
        self.assertTrue(all(t.requires_grad for t in sample.input))
        (out1, out2) = func(
            [sample.input, *sample.args],
            is_cuda=False,
            expect_fastpath=False,
            **sample.kwargs,
        )
        out1.backward(torch.ones_like(out1))
        self.assertIsNotNone(sample.input[0].grad)
        self.assertIsNone(sample.input[1].grad)

    @ops(
        filter(
            lambda op: op.backward_requires_result,
            foreach_unary_op_db
            + foreach_binary_op_db
            + foreach_pointwise_op_db
            + foreach_other_op_db,
        ),
        dtypes=(torch.float32,),
    )
    def test_lifetime_of_grad_fn_when_result_is_saved(self, device, dtype, op):
        def get_ref(func, sample):
            class Foo:
                pass

            out = func(
                (sample.input, *sample.args),
                is_cuda=False,
                expect_fastpath=False,
                **sample.kwargs,
            )
            foo = Foo()
            meta_dict = out[0].grad_fn.metadata
            meta_dict[0] = foo
            ref = weakref.ref(foo)
            return out, ref

        def _test(func, sample):
            out, ref = get_ref(func, sample)
            self.assertIsNotNone(ref())
            del out
            self.assertIsNone(ref())

        func = self._get_funcs(op)[0]
        for sample in op.sample_inputs(
            device, dtype, requires_grad=True, num_input_tensors=[1]
        ):
            for key in ("is_fastpath", "disable_fastpath"):
                if key in sample.kwargs:
                    del sample.kwargs[key]
            # note: `_foreach_pow.Scalar` and `_foreach_pow.ScalarList` don't depend on `result`
            # see: https://github.com/pytorch/pytorch/blob/5403c777/tools/autograd/derivatives.yaml#L3048-L3049
            if op.name == "_foreach_pow":
                if (
                    isinstance(sample.args[0], list)
                    and isinstance(sample.args[0][0], Number)
                ) or (
                    isinstance(sample.args[0], Number)
                    and not isinstance(sample.args[0], float)
                ):
                    continue
                if isinstance(sample.args[0], float):
                    new_args = (sample.input,)
                    sample.input = sample.args[0]
                    sample.args = new_args
            _test(func, sample)

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_tensors_grouping(self):
        num_tensors_per_list = 10
        num_devices = torch.cuda.device_count()
        dtypes = (torch.float16, torch.float32, torch.float64)
        list1 = [
            torch.tensor(
                i,
                device=torch.device("cuda", random.randint(0, num_devices - 1)),
                dtype=dtypes[random.randint(0, 2)],
            )
            for i in range(num_tensors_per_list)
        ]
        list2 = [None for _ in list1]
        list3 = [torch.rand_like(t) for t in list1]
        nested_tensorlists = [list1, list2, list3]
        grouped_tensors = torch.utils._foreach_utils._group_tensors_by_device_and_dtype(
            nested_tensorlists, with_indices=True
        )
        num_tensors_seen = 0
        for (device, dtype), ([l1, l2, l3], indices) in grouped_tensors.items():
            for t in itertools.chain(l1, l3):
                self.assertEqual(t.device, device)
                self.assertEqual(t.dtype, dtype)
                num_tensors_seen += 1
            self.assertEqual(len(l1), len(l2))
            self.assertTrue(all(p is None for p in l2))
            for i, index in enumerate(indices):
                self.assertEqual(l1[i], list1[index])
                self.assertEqual(l2[i], list2[index])
                self.assertEqual(l3[i], list3[index])
        self.assertEqual(num_tensors_seen, 2 * num_tensors_per_list)

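    # A minimal sketch of the grouping contract exercised above (editorial;
    # the concrete devices/dtypes are illustrative assumptions):
    #   ts = [torch.zeros((), device="cuda:0", dtype=torch.float32),
    #         torch.zeros((), device="cuda:0", dtype=torch.float16)]
    #   grouped = torch.utils._foreach_utils._group_tensors_by_device_and_dtype(
    #       [ts, [None, None]], with_indices=True
    #   )
    #   # -> {(cuda:0, float32): ([[ts[0]], [None]], [0]),
    #   #     (cuda:0, float16): ([[ts[1]], [None]], [1])}
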
    @onlyCUDA
    def test_0dim_tensor_overload_cpu_ok(self):
        tensors = [torch.ones((), device="cuda", dtype=torch.float32) for _ in range(2)]
        scalar_cpu_tensor = torch.tensor(4.0, device="cpu")

        # For mul and div, the scalar is allowed to be on CPU too
        actual = torch._foreach_mul(tensors, scalar_cpu_tensor)
        self.assertEqual(actual, [t.mul(scalar_cpu_tensor) for t in tensors])
        actual = torch._foreach_div(tensors, scalar_cpu_tensor)
        self.assertEqual(actual, [t.div(scalar_cpu_tensor) for t in tensors])

    @onlyCUDA
    def test_div_reciprocal(self):
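        # (editorial note) comparing frexp mantissas and exponents separately
        # makes even a one-ULP discrepancy visible, which would be expected if
        # the fast path computed division as multiplication by a reciprocal
        # rather than true division.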
        expect_m, expect_e = torch.frexp(
            torch.div(torch.tensor(0.1, device="cuda"), 10.0)
        )
        actual_m, actual_e = torch.frexp(
            torch._foreach_div([torch.tensor(0.1, device="cuda")], [10.0])[0]
        )
        self.assertEqual(expect_m, actual_m)
        self.assertEqual(expect_e, actual_e)

    @onlyCUDA
    def test_0dim_tensor_overload_exception(self):
        # check exceptions of fast path
        tensors = [
            make_tensor((2, 2), dtype=torch.float, device="cuda") for _ in range(2)
        ]
        with self.assertRaisesRegex(RuntimeError, "scalar tensor expected to be on"):
            torch._foreach_add(tensors, torch.tensor(1.0, device="cpu"), alpha=1.0)

        tensors = [
            make_tensor((2, 2), dtype=torch.float, device=d) for d in ("cpu", "cuda")
        ]
        with self.assertRaisesRegex(
            RuntimeError, "scalar tensor expected to be 0 dim but"
        ):
            torch._foreach_mul(tensors, torch.tensor([1.0, 1.0], device="cuda"))
        with self.assertRaisesRegex(
            RuntimeError, "scalar tensor expected to be 0 dim but"
        ):
            torch._foreach_add(tensors, torch.tensor([1.0, 1.0], device="cuda"))

    @onlyCUDA
    @ops(filter(lambda op: op.name == "_foreach_copy", foreach_binary_op_db))
    def test_foreach_copy_with_multi_device_inputs(self, device, dtype, op):
        foreach_copy_ = op.inplace_variant
        copy_ = op.ref_inplace
        for non_blocking in (False, True):
            for sample in op.sample_inputs(
                device, dtype, noncontiguous=False, allow_higher_dtype_scalars=True
            ):
                with torch.no_grad():
                    ref_input = [t.clone().detach() for t in sample.input]
                foreach_copy_(sample.input, sample.args[0], non_blocking)
                for t, s in zip(ref_input, sample.args[0]):
                    copy_(t, s, non_blocking)
                self.assertEqual(sample.input, ref_input)
                if torch.cuda.device_count() > 1:
                    device = torch.device("cuda", 1)
                    rhs_tensors = [t.to(device) for t in sample.args[0]]
                    foreach_copy_(sample.input, rhs_tensors, non_blocking)
                    for t, s in zip(ref_input, rhs_tensors):
                        copy_(t, s, non_blocking)
                    self.assertEqual(ref_input, sample.input)

    @onlyCUDA
    @ops(filter(lambda op: op.name == "_foreach_copy", foreach_binary_op_db))
    def test_foreach_copy_with_multi_dtypes(self, device, dtype, op):
        # check (a) multi_tensor_apply is called and (b) numerical parity with for-loop and Tensor.copy_
        foreach_copy_ = ForeachFuncWrapper(op.inplace_variant)
        for sample in op.sample_inputs(
            device, dtype, noncontiguous=False, allow_higher_dtype_scalars=True
        ):
            for src_dtype in floating_types_and(torch.half, torch.bfloat16):
                if src_dtype == dtype:
                    continue
                self_tensors = [t.clone() for t in sample.input]
                src_tensors = [t.to(src_dtype) for t in self_tensors]
                out = foreach_copy_(
                    (self_tensors, src_tensors), is_cuda=True, expect_fastpath=True
                )
                ref_out = [
                    torch.empty_like(t).copy_(s)
                    for t, s in zip(self_tensors, src_tensors)
                ]
                for t, ref_t in zip(out, ref_out):
                    self.assertTrue(torch.equal(t, ref_t))
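
        # Illustrative expectation (editorial; shapes/dtypes are assumptions):
        # _foreach_copy_ casts across dtypes just like Tensor.copy_, e.g.
        #   dst = [torch.zeros(2, device="cuda")]  # float32 by default
        #   src = [torch.ones(2, device="cuda", dtype=torch.bfloat16)]
        #   torch._foreach_copy_(dst, src)  # dst[0] == tensor([1., 1.]) in float32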

    # Test reverse-mode & forward-mode AD if supported.
    @onlyCUDA
    @ops(
        foreach_unary_op_db
        + foreach_binary_op_db
        + foreach_pointwise_op_db
        + foreach_reduce_op_db
        + foreach_other_op_db,
        dtypes=OpDTypes.supported,
        allowed_dtypes=(torch.float64, torch.complex128),
    )
    @parametrize(
        "inplace", (False, True), name_fn=lambda x: "inplace" if x else "outplace"
    )
    def test_autodiff(self, device, dtype, op, inplace):
        if (not inplace) and not op.supports_out:
            self.skipTest("out-of-place not implemented")
        if inplace and op.has_no_in_place:
            self.skipTest("in-place not implemented")
        if not (
            op.supports_autograd
            or op.supports_inplace_autograd
            or op.supports_forward_ad
        ):
            self.skipTest("neither reverse mode nor forward mode supported")

        # note(crcrpar): without restricting the input value range below, some of these
        # unary functions fail gradcheck, unlike their in-place and/or complex counterparts.
        if (
            (not inplace)
            and dtype == torch.float64
            and op.name
            in (
                "_foreach_acos",
                "_foreach_asin",
                "_foreach_log10",
                "_foreach_log1p",
                "_foreach_log2",
                "_foreach_log",
                "_foreach_pow",
                "_foreach_sqrt",
            )
        ):
            value_range = {"low": 0.5, "high": 1.0}
        else:
            value_range = {}
        for sample in op.sample_inputs(
            device,
            dtype,
            requires_grad=True,
            num_input_tensors=[5],
            allow_higher_dtype_scalars=True,
            **value_range,
        ):
            # Skip `_foreach_pow.ScalarAndTensor(Scalar, Tensor[])`
            if op.name == "_foreach_pow" and isinstance(sample.input, Number):
                continue

            func = None
            if inplace:
                # Call `clone` to avoid in-place modification of the inputs, as in
                # `torch.testing._internal.common_utils.TestGradients._get_safe_inplace`,
                # and return the mutated clones so gradcheck sees the op's output.
                def inplace_func(*tensorlist):
                    kwargs = (
                        {"alpha": sample.kwargs["alpha"]}
                        if "alpha" in sample.kwargs
                        else {}
                    )
                    cloned = tuple(t.clone() for t in tensorlist)
                    op.inplace_variant(cloned, *sample.args, **kwargs)
                    return cloned
1401*da0073e9SAndroid Build Coastguard Worker
1402*da0073e9SAndroid Build Coastguard Worker                func = inplace_func
1403*da0073e9SAndroid Build Coastguard Worker            else:
1404*da0073e9SAndroid Build Coastguard Worker
1405*da0073e9SAndroid Build Coastguard Worker                def outplace_func(*tensorlist):
1406*da0073e9SAndroid Build Coastguard Worker                    kwargs = (
1407*da0073e9SAndroid Build Coastguard Worker                        {"alpha": sample.kwargs["alpha"]}
1408*da0073e9SAndroid Build Coastguard Worker                        if "alpha" in sample.kwargs
1409*da0073e9SAndroid Build Coastguard Worker                        else {}
1410*da0073e9SAndroid Build Coastguard Worker                    )
1411*da0073e9SAndroid Build Coastguard Worker                    return op.method_variant(tensorlist, *sample.args, **kwargs)
1412*da0073e9SAndroid Build Coastguard Worker
1413*da0073e9SAndroid Build Coastguard Worker                func = outplace_func
1414*da0073e9SAndroid Build Coastguard Worker
1415*da0073e9SAndroid Build Coastguard Worker            working_sample, err_msg_pattern = check_autodiff_sample(
1416*da0073e9SAndroid Build Coastguard Worker                op, sample, dtype, inplace
1417*da0073e9SAndroid Build Coastguard Worker            )
1418*da0073e9SAndroid Build Coastguard Worker
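            # `gradcheck` compares analytical gradients against numerically
            # estimated ones. Forward-mode AD is exercised only when the op
            # declares support for it, and the batched (vmap) variants of the
            # checks are disabled here.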
            def call_gradcheck():
                gradcheck(
                    func,
                    sample.input,
                    raise_exception=True,
                    check_forward_ad=op.supports_forward_ad,
                    check_batched_forward_grad=False,
                    check_backward_ad=op.supports_autograd,
                    check_batched_grad=False,
                )

            if not working_sample:
                if not err_msg_pattern:
                    # A float64 lhs with a complex rhs raises no error, so
                    # just skip the sample.
                    continue
                with self.assertRaisesRegex(RuntimeError, re.escape(err_msg_pattern)):
                    call_gradcheck()
                continue
            call_gradcheck()

            # Test per-tensor `grad_fn` behavior.
            if inplace and op.supports_inplace_autograd:
                hook_buffer = []

                def get_grad_fn_hook(i):
                    def hook(grad_inputs, grad_outputs) -> None:
                        hook_buffer.append(i)

                    return hook

                _inputs = [t.clone().detach().requires_grad_() for t in sample.input]
                inputs = [t.clone() for t in _inputs]
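                # `_inputs` are the autograd leaves; the op runs on their
                # non-leaf clones because in-place mutation of a leaf that
                # requires grad is an error.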
                kwargs = (
                    {"alpha": sample.kwargs["alpha"]}
                    if "alpha" in sample.kwargs
                    else {}
                )
                op.inplace_variant(inputs, *sample.args, **kwargs)

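                # Each tensor in the list should get its own autograd node,
                # not one node shared by the whole tensorlist.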
                self.assertEqual(len({t.grad_fn for t in inputs}), len(inputs))

                for i, t in enumerate(inputs):
                    t.grad_fn.register_hook(get_grad_fn_hook(i))

                torch.autograd.grad(
                    inputs[0],
                    inputs=(_inputs[0],),
                    grad_outputs=(torch.rand_like(inputs[0]),),
                    retain_graph=True,
                )
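                # Backward through `inputs[0]` alone should fire only the
                # first tensor's grad_fn.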
                self.assertEqual(hook_buffer, [0])
                hook_buffer.clear()

                # The tensors may have different shapes, so flatten and
                # concatenate them to build a single scalar to differentiate.
                sum_of_cloned_tensors = torch.cat([t.view(-1) for t in inputs]).sum()
                grad_output = torch.rand_like(sum_of_cloned_tensors)
                torch.autograd.grad(
                    sum_of_cloned_tensors,
                    inputs=tuple(_inputs),
                    grad_outputs=(grad_output,),
                    retain_graph=False,
                )
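                # All hooks should fire, in reverse order of the tensors'
                # positions, since the autograd engine executes these
                # independent nodes in reverse order of creation.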
                self.assertEqual(hook_buffer, list(reversed(range(len(inputs)))))


# TODO(crcrpar): Hide this inside torch/testing/_internal. Doing so would add
# another layer to `foreach_inputs_sample_func.__call__` so that this function
# could be used as, e.g., the first argument of `filter`. Even after moving it
# there, checking the error message (as done here) still seems preferable.
def check_autodiff_sample(op, sample, dtype, is_inplace):
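    """Decide whether autodiff is expected to work for this sample.

    Returns a ``(working, err_msg_pattern)`` pair: ``(True, "")`` when
    gradcheck should pass, ``(False, pattern)`` when the op is expected to
    raise a ``RuntimeError`` matching ``pattern``, and ``(False, "")`` when
    the sample should simply be skipped.
    """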
    if op.name == "_foreach_abs" and is_inplace and dtype == torch.complex128:
        return False, "In-place abs is not supported for complex tensors."
    if op.name == "_foreach_sub" and (
        (
            isinstance(sample.args[-1], list)
            and any(isinstance(a, bool) for a in sample.args[-1])
        )
        or isinstance(sample.args[-1], bool)
    ):
        return False, _BOOL_SUB_ERR_MSG
    if op.name == "_foreach_norm" and (not is_inplace):
        return (
            False,
            "Trying to set a forward gradient that has a different size than that of the original Tensor, "
            "this is not supported. Tensor is of size [] while the given forward gradient is of size [1, 1].",
        )
    rhs_arg_has_complex_number = sample.args and (
        (
            isinstance(sample.args[-1], list)
            and any(isinstance(a, complex) for a in sample.args[-1])
        )
        or (isinstance(sample.args[-1], complex))
    )
    if rhs_arg_has_complex_number and dtype == torch.float64:
        if op.name in (
            "_foreach_clamp_max",
            "_foreach_clamp_min",
            "_foreach_maximum",
            "_foreach_minimum",
        ):
            return False, "clamp is not supported for complex types"
        if op.name == "_foreach_lerp" and is_inplace:
            return False, "value cannot be converted to type double without overflow"
        if not is_inplace:
            return False, ""
        if op.name == "_foreach_pow":
            return False, "Found dtype Double but expected ComplexDouble"
        if op.name in (
            "_foreach_add",
            "_foreach_sub",
            "_foreach_mul",
            "_foreach_div",
        ):
            return (
                False,
                "result type ComplexDouble can't be cast to the desired output type Double",
            )
    return True, ""

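# A minimal sketch of the `filter` usage the TODO above alludes to; the names
# here are only illustrative:
#
#     autodiff_samples = filter(
#         lambda s: check_autodiff_sample(op, s, dtype, is_inplace)[0],
#         op.sample_inputs(device, dtype, requires_grad=True),
#     )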

instantiate_device_type_tests(TestForeach, globals())


if __name__ == "__main__":
    run_tests()