# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import os
import unittest
from collections import namedtuple

from functorch_additional_op_db import additional_op_db

import torch
import torch.utils._pytree as pytree
from functorch import vmap
from torch.testing._internal.autograd_function_db import autograd_function_db
from torch.testing._internal.common_device_type import toleranceOverride
from torch.testing._internal.common_methods_invocations import DecorateInfo, op_db
from torch.testing._internal.common_modules import module_db
from torch.testing._internal.custom_op_db import custom_op_db


IS_FBCODE = os.getenv("FUNCTORCH_TEST_FBCODE") == "1"


def loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values):
    outs = []
    out_spec = None
    for idx in range(batch_size):
        flat_args, args_spec = pytree.tree_flatten(batched_args)
        flat_dims, dims_spec = pytree.tree_flatten(in_dims)
        assert args_spec == dims_spec
        new_args = [
            a.select(in_dim, idx) if in_dim is not None else a
            for a, in_dim in zip(flat_args, flat_dims)
        ]
        out = op(*pytree.tree_unflatten(new_args, args_spec), **kwarg_values)
        flat_out, out_spec = pytree.tree_flatten(out)
        outs.append(flat_out)

    # use the same out_dim for all outputs
    if isinstance(out_dim, int):
        flat_out_dim = [out_dim for _ in flat_out]
    else:
        flat_out_dim, _ = pytree.tree_flatten(out_dim)

    outs = zip(*outs)

    result = []
    for i, out_lst in enumerate(outs):
        if flat_out_dim[i] is not None:
            if not all(isinstance(x, torch.Tensor) for x in out_lst):
                raise ValueError(
                    f"vmap `{op}` must only return Tensors. "
                    "Did you mean to set out_dims=None for this output?"
                )
            result.append(torch.stack(out_lst))
        else:
            # not batched over, result should be the same for all batches
            result.append(out_lst[0])
    return pytree.tree_unflatten(result, out_spec)

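# A rough usage sketch for `loop` (illustrative only; the tensor name below is
# made up). `loop` computes the reference result that `vmap` is compared against
# by applying `op` to one slice of the batched args at a time and stacking:
#
#   x = torch.randn(3, 4)
#   expected = loop(torch.sum, (0,), 0, 3, x)       # stack of per-slice results
#   actual = vmap(torch.sum, in_dims=(0,))(x)
#   torch.testing.assert_close(actual, expected)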

# Like the `loop` helper function, but for two levels of vmap. If we ever need more
# levels than this, it is probably possible to generalize `loop`, but it seemed too
# complicated to be worth it for now.
def loop2(
    op,
    in_dims1,
    in_dims2,
    out_dim1,
    out_dim2,
    batch_size1,
    batch_size2,
    *batched_args,
    **kwarg_values,
):
    outs = []
    flat_args, args_spec = pytree.tree_flatten(batched_args)
    flat_dims1, dims_spec1 = pytree.tree_flatten(in_dims1)
    flat_dims2, dims_spec2 = pytree.tree_flatten(in_dims2)
    assert args_spec == dims_spec1
    assert args_spec == dims_spec2
    assert len(flat_dims1) == len(flat_dims2)
    for idx1 in range(batch_size1):
        out_split = []
        arg_split = [
            a.select(in_dim1, idx1) if in_dim1 is not None else a
            for a, in_dim1 in zip(flat_args, flat_dims1)
        ]
        for idx2 in range(batch_size2):
            new_args = [
                a.select(in_dim, idx2) if in_dim is not None else a
                for a, in_dim in zip(arg_split, flat_dims2)
            ]
            out = op(*pytree.tree_unflatten(new_args, args_spec), **kwarg_values)
            out_split.append(out)
        outs.append(out_split)

    loop_out = []
    for out_split in outs:
        if isinstance(out_split[0], torch.Tensor):
            loop_out.append(torch.stack(out_split, out_dim1))
        else:
            new_out = []
            for idx in range(len(out_split[0])):
                new_out.append(torch.stack([i[idx] for i in out_split], out_dim1))
            loop_out.append(new_out)

    new_out = []
    # `loop_out` is always a list here, so check its first element (mirroring the
    # inner loop above) to decide whether the op returned a single tensor.
    if isinstance(loop_out[0], torch.Tensor):
        new_out = torch.stack(loop_out, out_dim2)
    else:
        for idx in range(len(loop_out[0])):
            new_out.append(torch.stack([i[idx] for i in loop_out], out_dim2))
    return new_out

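# Rough sketch of what `loop2` computes (illustrative only): it is the reference
# for two nested levels of vmap. For example, for an `op` that returns a single
# tensor and a batched arg `x` of shape (B1, B2, ...):
#
#   expected = loop2(op, (0,), (0,), 0, 0, B1, B2, x)
#   actual = vmap(vmap(op))(x)
#
# should produce the same values.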

def is_valid_inplace_sample_input(sample_input, op, inplace_variant):
    if inplace_variant is None:
        return False
    if sample_input.broadcasts_input:
        return False
    if not isinstance(sample_input.input, torch.Tensor):
        return False

    # Check if input's dtype matches the output's dtype
    args = (sample_input.input,) + sample_input.args
    kwargs = sample_input.kwargs
    output_dtype = op(*args, **kwargs).dtype
    return sample_input.input.dtype == output_dtype


# This is kind of dangerous; please think carefully before using it.
# Known risks:
# - The returned value had better not be mutated, so it's best for `fn` to return
#   immutable types (e.g. prefer tuples to lists).
# - Don't memoize on Tensor arguments in a global context; the cache will keep
#   them alive forever.
def memoize(fn):
    memo = {}

    def wrapped(*args):
        if args not in memo:
            memo[args] = fn(*args)
        return memo[args]

    return wrapped

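# Minimal usage sketch for `memoize` (the function below is made up): the wrapper
# caches results keyed on the positional args, so repeated calls with the same
# (hashable) arguments reuse the first result.
#
#   @memoize
#   def fib(n):
#       return n if n < 2 else fib(n - 1) + fib(n - 2)
#
#   fib(30)  # each distinct n is computed only once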

# NB: This is O(2 ** num_tensors).
# num_tensors ranges from 1 to 10, with 2-4 being most common.
# Try not to exacerbate the blowup if you're modifying it.
@memoize
def get_bdim_choices(num_tensors):
    choices = []

    # full of zeros
    choices.append((0,) * num_tensors)

    # Every assignment of (-1, None) across the tensors (Cartesian product)
    options = (-1, None)
    choices.extend(itertools.product(options, repeat=num_tensors))

    assert choices[-1] == (None,) * num_tensors
    return tuple(choices[:-1])

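# For reference, get_bdim_choices(2) returns
#   ((0, 0), (-1, -1), (-1, None), (None, -1))
# i.e. the all-zeros choice plus every (-1, None) assignment except the all-None
# one (which would mean nothing is batched).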

# NB: This is O(2 ** num_tensors).
# num_tensors ranges from 1 to 10, with 2-4 being most common.
# Try not to exacerbate the blowup if you're modifying it.
def get_bdim_choices_batch_norm(
    num_tensors, _, running_mean=None, running_var=None, *args
):
    choices = []
    options = (-1, None)

    # instance norm turns these into unbatched 0 tensors, so we cannot batch the
    # input if either running_mean or running_var is not specified
    if running_mean is None or running_var is None:
        choices.append((None,) + (0,) * (num_tensors - 1))
        for choice in itertools.product(options, repeat=num_tensors - 1):
            choices.append((None,) + choice)

    else:
        # running_mean and running_var are specified as tensors. Batch norm doesn't
        # work if the input is batched but running_mean/var are unbatched, so this
        # tests all other cases
        choices.append((0,) * num_tensors)
        for choice in itertools.product(options, repeat=num_tensors):
            input_bdim = choice[0]
            running_mean_bdim = choice[1]
            running_var_bdim = choice[2]
            if input_bdim and (not running_mean_bdim or not running_var_bdim):
                continue
            choices.append(choice)

    assert choices[-1] == (None,) * num_tensors
    return tuple(choices[:-1])

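# For reference (arg names made up), with batched running stats and three tensor
# args (input, running_mean, running_var), get_bdim_choices_batch_norm(3, inp, rm, rv)
# keeps the all-zeros choice plus every (-1, None) assignment except those where the
# input is batched but running_mean/var are not (and except the all-None one):
#   ((0, 0, 0), (-1, -1, -1), (None, -1, -1), (None, -1, None), (None, None, -1))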

def add_batch_dim(arg, bdim, batch_size=3):
    assert bdim == 0 or bdim == -1
    assert isinstance(arg, torch.Tensor)
    if bdim == 0:
        shape = [1] * len(arg.shape)
        shape.insert(bdim, batch_size)
        return (arg.repeat(shape), bdim)
    if bdim == -1:
        arg = arg.unsqueeze(-1).expand(*arg.shape, batch_size).contiguous()
        return (arg, bdim)

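# Shape sketch for `add_batch_dim` (the tensor name is made up): for `x` of
# shape (2, 5),
#   add_batch_dim(x, 0, batch_size=3)   -> tensor of shape (3, 2, 5), bdim 0
#   add_batch_dim(x, -1, batch_size=3)  -> tensor of shape (2, 5, 3), bdim -1
# In both cases every slice along the batch dim is a copy of `x`.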

def construct_in_dims(bdim_choice_for_tensors, is_tensors):
    result = []
    bdim = iter(bdim_choice_for_tensors)
    for is_tensor in is_tensors:
        if not is_tensor:
            result.append(None)
            continue
        result.append(next(bdim))
    return tuple(result)

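# For reference, `construct_in_dims` spreads a bdim choice (one entry per tensor)
# back over the full flat argument list, e.g.
#   construct_in_dims((0, -1), [True, False, True]) == (0, None, -1)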

def is_batch_norm_training(op_name, kwarg_values):
    batch_norm_fns = (
        "nn.functional.batch_norm",
        "nn.functional.instance_norm",
    )  # instance norm calls batch norm
    if op_name not in batch_norm_fns:
        return False

    # batch norm and instance norm require the value to be a plain bool
    default_training = (
        op_name == "nn.functional.instance_norm"
    )  # instance norm defaults to training, batch norm doesn't
    is_training = tuple(
        arg for arg in tuple(kwarg_values.values()) if isinstance(arg, bool)
    )
    if len(is_training) == 0:
        return default_training
    else:
        assert len(is_training) == 1
        return is_training[0]


def generate_vmap_inputs(
    arg_values, kwarg_values, is_batch_norm_and_training=False, batch_size=2
):
    flat_args, arg_spec = pytree.tree_flatten(tuple(arg_values))
    is_tensors = [isinstance(a, torch.Tensor) for a in flat_args]
    num_tensors = sum(is_tensors)
    # For Batch Norm, if there's only an input, we can't
    # batch it since running_mean/var will be seen as unbatched tensors
    if num_tensors == 1 and is_batch_norm_and_training:
        return
    bdim_choices = (
        get_bdim_choices_batch_norm(num_tensors, *arg_values)
        if is_batch_norm_and_training
        else get_bdim_choices(num_tensors)
    )

    @memoize
    def get_batched_arg(arg, bdim):
        assert isinstance(arg, torch.Tensor)
        assert bdim is not None
        result, _ = add_batch_dim(arg, bdim, batch_size)
        return result

    for bdim_choice in bdim_choices:
        flat_in_dims = construct_in_dims(bdim_choice, is_tensors)

        flat_batched_args = tuple(
            arg if in_dim is None else get_batched_arg(arg, in_dim)
            for arg, in_dim in zip(flat_args, flat_in_dims)
        )
        batched_args = pytree.tree_unflatten(flat_batched_args, arg_spec)
        in_dims = pytree.tree_unflatten(flat_in_dims, arg_spec)
        yield batched_args, in_dims, kwarg_values

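# Typical usage sketch for `generate_vmap_inputs` (illustrative; `op` and `x` are
# made up): each yielded triple can be fed straight into vmap or into the helpers
# below.
#
#   for batched_args, in_dims, kwargs in generate_vmap_inputs((x,), {}):
#       out = vmap(op, in_dims=in_dims)(*batched_args, **kwargs)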

def clone_if_tensor(x):
    if isinstance(x, torch.Tensor):
        return x.clone()
    return x


# Helper function to compare output of `vmap` against the
# `for-loop` version.
def _compute_quantities_for_vmap_test(
    op,
    orig_batched_args,
    orig_kwarg_values,
    in_dims,
    out_dim,
    batch_size,
    compute_loop_out=True,
    clone_inputs=False,
):
    def maybe_clone_inputs():
        if clone_inputs:
            batched_args = pytree.tree_map(clone_if_tensor, orig_batched_args)
            kwarg_values = pytree.tree_map(clone_if_tensor, orig_kwarg_values)
            return batched_args, kwarg_values
        return orig_batched_args, orig_kwarg_values

    batched_args, kwarg_values = maybe_clone_inputs()

    if compute_loop_out:
        loop_out = loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values)
    else:
        loop_out = None

    # Used for debugging the resulting operations
    # from functorch import make_fx
    # def f(a):
    #     return op(a)
    # t = make_fx(vmap(f, in_dims=in_dims, out_dims=out_dim))(*batched_args, **kwarg_values)
    # print(in_dims, [arg.shape for arg in batched_args], kwarg_values)
    batched_args, kwarg_values = maybe_clone_inputs()
    batched_out = vmap(op, in_dims=in_dims, out_dims=out_dim)(
        *batched_args, **kwarg_values
    )

    # Tests the case where we dispatch to a batching rule with no bdims.
    # This should be handled by the autogenerated plumbing. For vmap support
    # added via manual plumbing you may need to handle this specially.
    def add_bdim_if_tensor(x):
        if isinstance(x, torch.Tensor):
            return x.unsqueeze(1)
        return x

    def f(dummy, *args, **kwargs):
        return op(*args, **kwargs)

    dummy = torch.ones(batch_size, 1)
    vmapvmap_expected = pytree.tree_map(add_bdim_if_tensor, batched_out)

    inner_in_dims = (0,) + pytree.tree_map(lambda x: None, in_dims)
    outer_in_dims = (0,) + in_dims
    batched_args, kwarg_values = maybe_clone_inputs()
    vmapvmap_output = vmap(
        vmap(f, inner_in_dims, out_dims=out_dim), outer_in_dims, out_dims=out_dim
    )(dummy, *batched_args, **kwarg_values)

    yield (batched_out, loop_out, vmapvmap_output, vmapvmap_expected)


# Function with friendlier return types
# compared to `_compute_quantities_for_vmap_test`
def compute_quantities_for_vmap_test(
    op,
    orig_batched_args,
    orig_kwarg_values,
    in_dims,
    out_dim=0,
    batch_size=2,
    compute_loop_out=True,
    clone_inputs=False,
):
    for quantities in _compute_quantities_for_vmap_test(
        op,
        orig_batched_args,
        orig_kwarg_values,
        in_dims,
        out_dim,
        batch_size,
        compute_loop_out,
        clone_inputs,
    ):
        yield (quantities[0], quantities[1])
        yield (quantities[2], quantities[3])

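# Typical usage sketch in a test (illustrative; `op`, `args`, `kwargs` are made
# up): each yielded pair is (actual, expected) and can be compared directly.
#
#   for batched_args, in_dims, kwarg_values in generate_vmap_inputs(args, kwargs):
#       for actual, expected in compute_quantities_for_vmap_test(
#           op, batched_args, kwarg_values, in_dims
#       ):
#           torch.testing.assert_close(actual, expected)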

def get_fallback_and_vmap_exhaustive(
    op,
    arg_values,
    kwarg_values,
    is_batch_norm_and_training=False,
    compute_loop_out=True,
):
    out_dim = 0
    batch_size = 2

    def make_batched(t):
        if isinstance(t, torch.Tensor):
            shape = list(t.shape)
            shape.insert(out_dim, batch_size)
            return t.expand(*shape)
        return t

    # Inputs generated by `generate_vmap_inputs` just copy/expand the unbatched inputs
    # over the batched dimension. Thus we can compute the expected value once and just
    # expand it based on the `out_dim` and `batch_size`.
    expected_unbatched = op(*arg_values, **kwarg_values)
    expected_batched = pytree.tree_map(make_batched, expected_unbatched)
    generator = generate_vmap_inputs(
        arg_values, kwarg_values, is_batch_norm_and_training
    )
    for batched_args, in_dims, kwarg_values in generator:
        for quantities in _compute_quantities_for_vmap_test(
            op,
            batched_args,
            kwarg_values,
            in_dims,
            out_dim,
            batch_size,
            compute_loop_out=False,
        ):
            assert quantities[1] is None
            yield (quantities[0], expected_batched)
            yield (quantities[2], quantities[3])

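# Typical usage sketch (illustrative; `op`, `args` and `kwargs` are made up): the
# generator yields (actual, expected) pairs over all batching configurations.
#
#   for actual, expected in get_fallback_and_vmap_exhaustive(op, args, kwargs):
#       torch.testing.assert_close(actual, expected)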

def opinfo_in_dict(opinfo, d):
    return (opinfo.name in d) or (f"{opinfo.name}.{opinfo.variant_test_name}" in d)


DecorateMeta = namedtuple(
    "DecorateMeta",
    [
        "op_name",
        "variant_name",
        "decorator",
        "device_type",
        "dtypes",
    ],
)


def decorate(
    op_name, variant_name="", *, decorator=None, device_type=None, dtypes=None
):
    assert decorator is not None
    return DecorateMeta(
        op_name=op_name,
        variant_name=variant_name,
        decorator=decorator,
        device_type=device_type,
        dtypes=dtypes,
    )


def xfail(op_name, variant_name="", *, device_type=None, dtypes=None):
    return decorate(
        op_name=op_name,
        variant_name=variant_name,
        decorator=unittest.expectedFailure,
        device_type=device_type,
        dtypes=dtypes,
    )


def skip(op_name, variant_name="", *, device_type=None, dtypes=None):
    return decorate(
        op_name=op_name,
        variant_name=variant_name,
        decorator=unittest.skip("Skipped!"),
        device_type=device_type,
        dtypes=dtypes,
    )


def skipOps(test_case_name, base_test_name, to_skip):
    all_opinfos = op_db + additional_op_db + autograd_function_db + custom_op_db
    for decorate_meta in to_skip:
        matching_opinfos = [
            o
            for o in all_opinfos
            if o.name == decorate_meta.op_name
            and o.variant_test_name == decorate_meta.variant_name
        ]
        assert len(matching_opinfos) > 0, f"Couldn't find OpInfo for {decorate_meta}"
        assert len(matching_opinfos) == 1, (
            "OpInfos should be uniquely determined by their (name, variant_name). "
            f"Got more than one result for ({decorate_meta.op_name}, {decorate_meta.variant_name})"
        )
        opinfo = matching_opinfos[0]
        decorators = list(opinfo.decorators)
        new_decorator = DecorateInfo(
            decorate_meta.decorator,
            test_case_name,
            base_test_name,
            device_type=decorate_meta.device_type,
            dtypes=decorate_meta.dtypes,
        )
        decorators.append(new_decorator)
        opinfo.decorators = tuple(decorators)

    # This decorator doesn't modify fn in any way
    def wrapped(fn):
        return fn

    return wrapped

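# Usage sketch for `skipOps` (the test and op names below are made up). It is
# applied as a decorator on an OpInfo-parametrized test; the real work is the
# side effect of attaching xfail/skip decorators to the matching OpInfo entries.
#
#   @skipOps("TestVmapOperators", "test_op", {
#       xfail("some_op"),
#       skip("another_op", "variant", device_type="cpu"),
#   })
#   def test_op(self, device, dtype, op):
#       ...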

def decorateForModules(decorator, module_classes, device_type=None, dtypes=None):
    # This decorator doesn't modify fn in any way
    def wrapped(
        fn,
        module_classes=module_classes,
        decorator=decorator,
        device_type=device_type,
        dtypes=dtypes,
    ):
        name_parts = fn.__qualname__.split(".")
        assert (
            len(name_parts) == 2
        ), "Decorator only applies to a test function of a test class"
        test_case_name, base_test_name = name_parts
        for module_cls in module_classes:
            matching_module_infos = [m for m in module_db if m.module_cls == module_cls]
            assert (
                len(matching_module_infos) == 1
            ), f"Couldn't find single ModuleInfo for {module_cls}"
            module_info = matching_module_infos[0]
            decorators = list(module_info.decorators)
            new_decorator = DecorateInfo(
                decorator,
                test_case_name,
                base_test_name,
                device_type=device_type,
                dtypes=dtypes,
            )
            decorators.append(new_decorator)
            module_info.decorators = tuple(decorators)
        return fn

    return wrapped


def expectedFailureIf(condition):
    def decorator(fn):
        if condition:
            return unittest.expectedFailure(fn)
        return fn

    return decorator


def tol2(op_name, variant_name, override_dct, *, device_type=None):
    return (op_name, variant_name, override_dct, device_type)


def tol1(op_name, override_dct, *, device_type=None):
    return tol2(op_name, "", override_dct, device_type=device_type)


def opsToleranceOverride(test_case_name, base_test_name, overrides):
    all_opinfos = op_db + additional_op_db
    for override in overrides:
        # Don't rebind `override` here so the assert message below still shows the
        # full (op_name, variant_name, override_dct, device_type) tuple.
        op_name, variant_name, override_dct, device_type = override
        matching_opinfos = [
            o
            for o in all_opinfos
            if o.name == op_name and o.variant_test_name == variant_name
        ]
        assert len(matching_opinfos) == 1, f"Couldn't find OpInfo for {override}"
        opinfo = matching_opinfos[0]
        decorators = list(opinfo.decorators)
        decorators.append(
            DecorateInfo(
                toleranceOverride(override_dct),
                test_case_name,
                base_test_name,
                device_type=device_type,
            )
        )
        opinfo.decorators = tuple(decorators)

    # This decorator doesn't modify fn in any way
    def wrapped(fn):
        return fn

    return wrapped

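# Usage sketch for `opsToleranceOverride` (op and test names are made up; `tol`
# comes from torch.testing._internal.common_device_type):
#
#   @opsToleranceOverride("TestOperators", "test_grad", (
#       tol1("some_op", {torch.float32: tol(atol=1e-4, rtol=1e-4)}),
#       tol2("other_op", "variant", {torch.float32: tol(atol=1e-3, rtol=1e-3)},
#            device_type="cuda"),
#   ))
#   def test_grad(self, device, dtype, op):
#       ...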

class DisableVmapFallback:
    def __enter__(self):
        self.prev_state = torch._C._functorch._is_vmap_fallback_enabled()
        torch._C._functorch._set_vmap_fallback_enabled(False)

    def __exit__(self, *ignored):
        torch._C._functorch._set_vmap_fallback_enabled(self.prev_state)


def check_vmap_fallback(test_case, thunk, opinfo, dry_run=False):
    try:
        with DisableVmapFallback():
            thunk()
    except Exception:
        if not dry_run:
            raise
        if opinfo.variant_test_name:
            print(f"xfail('{opinfo.name}', '{opinfo.variant_test_name}'),")
        else:
            print(f"xfail('{opinfo.name}'),")
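# Usage sketch for `check_vmap_fallback` (illustrative, inside a test method):
# run the test body with the vmap fallback disabled so missing batching rules
# surface as errors; with dry_run=True, failures are printed as ready-to-paste
# xfail(...) lines instead of being raised.
#
#   def test():
#       for actual, expected in get_fallback_and_vmap_exhaustive(op, args, kwargs):
#           torch.testing.assert_close(actual, expected)
#
#   check_vmap_fallback(self, test, opinfo, dry_run=False)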
595