# mypy: ignore-errors

import torch
from torch import Tensor
import itertools

from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
from torch.utils import _pytree as pytree
from functools import partial
from torch.utils._mode_utils import no_dispatch, all_same_mode
import torch.autograd.forward_ad as fwAD
from typing import Callable
import re


def check_attr_consistency(wrapper_tensor, metadata_name, metadata_accessor):
    elem = wrapper_tensor.elem
    metadata_wrapper_tensor = metadata_accessor(wrapper_tensor)
    metadata_elem = metadata_accessor(elem)
    if metadata_wrapper_tensor == metadata_elem:
        return
    raise RuntimeError(
        f"This operator is not Composite Compliant: the "
        f"{metadata_name} of the tensor was modified directly without "
        f"going through the PyTorch dispatcher.")

def check_metadata_consistency(wrapper_tensor, CCT):
    # CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode
    if not isinstance(wrapper_tensor, CCT):
        return
    things_to_check = {
        'shape': Tensor.size,
        'dtype': lambda x: x.dtype,
        'device': lambda x: x.device,
        'numel': Tensor.numel,
        'stride': Tensor.stride,
        'storage_offset': Tensor.storage_offset,
    }
    for metadata_name, metadata_accessor in things_to_check.items():
        check_attr_consistency(wrapper_tensor, metadata_name, metadata_accessor)

def is_view_fn(func):
    return func.overloadpacket.__name__ in {
        'as_strided',
        'detach',
        'diagonal',
        'expand',
        'expand_as',
        'movedim',
        'narrow',
        'permute',
        'select',
        'squeeze',
        'transpose',
        't',
        'real',
        'imag',
        'view_as_real',
        'view_as_complex',
        'unflatten',
        'unfold',
        'unsqueeze',
        'view',
        'view_as',
        'unbind',
        'split',
        'split_with_sizes',
        'vsplit',
        'hsplit',
        'tensor_split',
        'chunk',
        'swapaxes',
        'slice',
        '_reshape_alias',
        '_unsafe_view',
        '_conj',
        'alias',
    }

# manually populated from native_functions that have inplace_view: True.
# In the future we will probably be able to grab that list directly
def is_inplace_view_fn(func):
    return func.overloadpacket.__name__ in {
        'as_strided_',
        'detach_',
        'squeeze_',
        'swapaxes_',
        'swapdims_',
        't_',
        'transpose_',
        'unsqueeze_',
    }


# Introspection please save us
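# A rough illustration of the heuristic below (not exhaustive): names like
# 'add_' and '__iadd__' are treated as in-place, while 'add' and '__add__'
# are treated as out-of-place.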
def is_inplace(func):
    name = func.overloadpacket.__name__
    if re.match('__i.+__', name):
        return True
    if re.match('__.+__', name):
        return False
    return name[-1] == '_'


def generate_cct_and_mode(autograd_view_consistency=True):
    # This function returns a new class CompositeCompliantTensor and an
    # instance of CompositeCompliantTensorMode.
    # The argument controls the behaviour described below.

    # autograd_view_consistency:
    #   If True, alias result using `set_` if func returns a view
    #   (See Note [Alias Result]).
    #   Since Forward AD doesn't work with `set_`,
    #   check_forward_ad_formula disables it by passing autograd_view_consistency=False.
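    #
    # Rough usage sketch (illustrative only; the checkers further below are the
    # intended entry points):
    #   CCT, cct_mode = generate_cct_and_mode()
    #   wrapped = CCT(torch.randn(3), cct_mode)
    #   # Ops on `wrapped` now dispatch through CompositeCompliantTensorMode,
    #   # which rejects non-compliant calls such as set_() and resize_().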

    class CompositeCompliantTensor(torch.Tensor):
        elem: torch.Tensor

        __slots__ = ['elem']

        @staticmethod
        def __new__(cls, elem, mode, *args, **kwargs):
            assert type(elem) is not cls, \
                "Wrapping a CompositeCompliantTensor in a CompositeCompliantTensor is not supported"

            # The storage of CompositeCompliantTensor should never be used directly
            # by a Composite operation; if the Composite
            # operator attempts to read from the storage without dispatching then it'll
            # raise a RuntimeError due to it being a meta storage.
            r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
                cls, elem.size(),
                dtype=elem.dtype, layout=elem.layout,
                device=elem.device, requires_grad=elem.requires_grad,
                strides=elem.stride(), storage_offset=elem.storage_offset())

            if elem.requires_grad:
                # CompositeCompliantTensor steals the "requires_grad"-ness.
                # Why a new copy of `elem`? Because sometimes OpInfo shares inputs between tests...
                tmp = torch.empty_strided(elem.shape, elem.stride(), dtype=elem.dtype,
                                          device=elem.device, layout=elem.layout,
                                          requires_grad=False)
                tmp.copy_(elem.detach())
                r.elem = tmp
            else:
                r.elem = elem

            assert r.stride() == r.elem.stride()

            # Propagate conjugate bits to the wrapper tensor
            # Ref: https://github.com/albanD/subclass_zoo/issues/24
            # Ref: https://github.com/albanD/subclass_zoo/issues/21
            torch._C._set_conj(r, r.elem.is_conj())
            torch._C._set_neg(r, r.elem.is_neg())

            r.mode = mode
            return r

        def __repr__(self):
            return f"CompositeCompliantTensor({self.elem})"

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            all_args = pytree.arg_tree_leaves(*args, **(kwargs or {}))
            modes = tuple(e.mode for e in all_args if isinstance(e, CompositeCompliantTensor))
            if not all_same_mode(modes):
                raise RuntimeError("Multiple CompositeCompliantTensorModes NYI")
            with modes[0]:
                return func(*args, **kwargs)

    class CompositeCompliantTensorMode(TorchDispatchMode):
        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            def unwrap(e):
                return e.elem if isinstance(e, CompositeCompliantTensor) else e

            def wrap(e):
                return CompositeCompliantTensor(e, self) if isinstance(e, torch.Tensor) else e

            if func == torch.ops.aten._local_scalar_dense.default:
                raise RuntimeError(
                    ".item() is not allowed to be called inside of composite "
                    "functions in the PyTorch library because not all backends "
                    "and/or Tensor subclasses (e.g. vmap, ProxyTensor) support it.")

            if func.overloadpacket.__name__ in ('set_', 'resize_'):
                raise RuntimeError(
                    f"{func.__name__} is not allowed to be called inside of "
                    f"Composite operators.")

            if is_inplace(func):
                # NB: We are making an assumption that if the function is in-place,
                # then the first argument is being written to. Introspection please save us!
                mutated_argument = args[0]
                if not isinstance(mutated_argument, CompositeCompliantTensor) and \
                        any(isinstance(a, CompositeCompliantTensor) for a in args[1:]):
                    raise RuntimeError(
                        'Not composite compliant: performing in-place operation '
                        f'{func.__name__} where the Tensor being written to is '
                        'a regular Tensor but the other tensors are Tensor Subclasses. '
                        'Please try to avoid this in-place operation.')

            unwrapped_args = tree_map(unwrap, args)
            unwrapped_kwargs = tree_map(unwrap, kwargs)
            unwrapped_rs = func(*unwrapped_args, **unwrapped_kwargs)
            rs = tree_map(wrap, unwrapped_rs)

            if is_view_fn(func) and autograd_view_consistency:
                # Note [Alias Result]
                # Autograd asserts that for B = A.view_fn(...), B and A's storages
                # are the same. Here we try to make B alias A to avoid those asserts.
                # See https://github.com/pytorch/pytorch/issues/65339 for more information
                # about the issue.
                with no_dispatch():
                    # Idea: this is a weird way of getting a storage that aliases the input.
                    # This is a workaround for #65339.
                    # 1. under no_dispatch, all of the wrapper tensors look like regular
                    #    tensors with special storage (the storage is nullptr and
                    #    advertises CPU/CUDA device).
                    # 2. we run func, which ends up running the view operation
                    # 3. All view operations reuse the input's storage and return
                    #    result Tensor(s) with new sizes/strides/offset that alias
                    #    the input.
                    # 4. we set the storage (and sizes/strides/offset) of the wrapper
                    #    tensor results to be that of the tensors that alias the input
                    result = func(*args, **kwargs)
                    if isinstance(result, (tuple, list)):
                        for a, b in zip(rs, result):
                            a.set_(b)
                    else:
                        rs.set_(result)

            # Some operations are allowed to in-place modify the metadata of the
            # inputs. The only ones are the "inplace view functions"; when we
            # run into these, we manually modify the metadata of the input.
            with no_dispatch():
                if is_inplace_view_fn(func):
                    func(*args, **kwargs)

            # For each CompositeCompliantTensor t, we check that t and t.elem
            # have consistent metadata. If they don't have consistent metadata,
            # that means the operator did something fishy.
            check = partial(check_metadata_consistency, CCT=CompositeCompliantTensor)
            pytree.tree_map_(check, args)
            pytree.tree_map_(check, kwargs)
            pytree.tree_map_(check, rs)
            return rs

    return CompositeCompliantTensor, CompositeCompliantTensorMode()

def is_tensorlist(lst):
    if not isinstance(lst, list) and not isinstance(lst, tuple):
        return False
    if len(lst) == 0:
        return False
    all_tensors = all(isinstance(elt, torch.Tensor) for elt in lst)
    if all_tensors:
        return True
    exists_one_tensor = any(isinstance(elt, torch.Tensor) for elt in lst)
    if exists_one_tensor:
        raise RuntimeError('This test assumes that PyTorch APIs cannot take '
                           'mixed lists of Tensor and other things')
    return False


def maybe_map(fn, should_map, arg):
    return fn(arg) if should_map else arg


def wrap(arg, CCT, cct_mode):
    # CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode
    if isinstance(arg, torch.Tensor):
        return CCT(arg, cct_mode)
    if is_tensorlist(arg):
        return [CCT(a, cct_mode) for a in arg]
    raise RuntimeError("wrap assumes that the input can be wrapped")


# Given a list of flat arguments, some of which may be Tensors, return all
# possible ways some of the arguments could be CompositeCompliantTensors (CCT).
# For example, given Tensors A, B and flat_args = [A, 1, B],
# we would return the following 4 options:
# [CCT(A), 1, CCT(B)]
# [CCT(A), 1, B]
# [A, 1, CCT(B)]
# [A, 1, B]
# NB: Yes, this is exponential. No, we don't care too much because PyTorch ops
# don't accept that many input Tensors.
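#
# Illustrative sketch of how this generator is consumed (A and B above are
# assumed to be plain Tensors; CCT/cct_mode come from generate_cct_and_mode()):
#   for choice, which_wrapped in generate_subclass_choices([A, 1, B], CCT, cct_mode):
#       # `choice` is the (possibly wrapped) argument list and `which_wrapped`
#       # is the matching tuple of booleans, e.g. (True, False, True).
#       ...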
def generate_subclass_choices(flat_args, CCT, cct_mode):
    # CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode
    is_tensor_likes = [isinstance(arg, torch.Tensor) or is_tensorlist(arg) for arg in flat_args]
    subclass_options = [[False, True] if is_tensor_like else [False] for is_tensor_like in is_tensor_likes]

    for which_args_are_wrapped in itertools.product(*subclass_options):

        result = [maybe_map(partial(wrap, CCT=CCT, cct_mode=cct_mode), should_wrap_arg, arg)
                  for should_wrap_arg, arg in zip(which_args_are_wrapped, flat_args)]
        yield result, which_args_are_wrapped


# For an operation f(*args, **kwargs), each Tensor argument may either be
# a regular Tensor or a Tensor Subclass. This iterator iterates through
# all of those options.
def generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
    # CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode
    flat_kwargs, spec = tree_flatten(kwargs)
    flat_args_kwargs = list(args) + list(flat_kwargs)
    for choice, debug_metadata in generate_subclass_choices(flat_args_kwargs, CCT, cct_mode):
        new_args = choice[:len(args)]
        new_kwargs = tree_unflatten(choice[len(args):], spec)
        which_args_are_wrapped = debug_metadata[:len(args)]
        which_kwargs_are_wrapped = tree_unflatten(debug_metadata[len(args):], spec)
        yield new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped


def raise_composite_compliance_error(err, additional_info=''):
    raise RuntimeError(
        "Composite compliance check failed with "
        "the above error.\n"
        f"{additional_info}"
        "If you are adding an OpInfo of an "
        "existing operator, please feel free to skip this test "
        "because the problem was pre-existing and file an issue. "
        "Otherwise, if you added a new operator, please read "
        "through the Composite Compliance section in "
        "aten/src/ATen/native/README.md for how to resolve this. "
    ) from err


# This test checks ALL possible permutations of calling `op` with arguments
# that are individually either a regular Tensor or a Tensor subclass.
#
# The general strategy is to wrap some Tensor args and kwargs in
# CompositeCompliantTensor wrappers and call the operation.

# If some composite operation does any non-compliant behavior,
# CompositeCompliantTensor will raise an error.
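#
# Hedged usage sketch (illustrative; assumes `a` and `b` are plain Tensors):
#   check_all_permutations(torch.add, (a, b), {}, torch.testing.assert_close)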
def check_all_permutations(op, args, kwargs, assert_equal_fn):
    CCT, cct_mode = generate_cct_and_mode()
    expected = op(*args, **kwargs)
    for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
        new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice

        try:
            actual = op(*new_args, **new_kwargs)
        # NOTE: [What errors are Composite Compliance trying to catch?]
        #
        # There are two things we want to catch:
        # - errors that would raise within the torch_dispatch impl
        # - data_ptr accesses
        # The first is easy to filter for (we could make the error a different
        # error class); the second is always going to be a RuntimeError due to
        # how it is implemented (if you try to access the data_ptr of the
        # wrapper Tensor, it raises an internal RuntimeError).
        #
        # So the most general thing to catch here is RuntimeError. If you
        # are here and debugging why your test failed, it's plausible that
        # the operator itself is broken and that there are other tests failing.
        except RuntimeError as err:
            raise_composite_compliance_error(
                err,
                f"- wrapped_args: {which_args_are_wrapped}\n"
                f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n"
            )

        def unwrap(e):
            return e.elem if isinstance(e, CCT) else e

        assert_equal_fn(tree_map(unwrap, actual), expected)

# Checks, via the use of a torch dispatch mode, for certain anti-patterns
# that are not composite compliant.
#
# In particular, the anti-pattern we are trying to prevent is a user
# creating an empty tensor and then resize_-ing it. Torch Dispatch Mode helps
# here because all factory functions will create tensors that are
# CompositeCompliantTensor.
#
# The general strategy is to wrap all Tensor args and kwargs in
# CompositeCompliantTensor wrappers. If an operator that is
# Composite does any non-compliant behavior,
# CompositeCompliantTensor will raise an error.
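#
# For example (illustrative, with `x` as a stand-in input), a composite
# implementation that does
#   buf = torch.empty(0); buf.resize_(x.shape); buf.copy_(x)
# would be flagged here, because resize_ is rejected by
# CompositeCompliantTensorMode and the freshly created `buf` is itself a
# CompositeCompliantTensor while the mode is active.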
def check_with_mode(op, args, kwargs, assert_equal_fn):
    CCT, cct_mode = generate_cct_and_mode()

    def wrap(e):
        return CCT(e, cct_mode) if isinstance(e, torch.Tensor) else e

    expected = op(*args, **kwargs)

    args = tree_map(wrap, args)
    kwargs = tree_map(wrap, kwargs)
    try:
        with cct_mode:
            actual = op(*args, **kwargs)
    # see NOTE: [What errors are Composite Compliance trying to catch?]
    except RuntimeError as err:
        raise_composite_compliance_error(err)

    def unwrap(e):
        return e.elem if isinstance(e, CCT) else e

    assert_equal_fn(tree_map(unwrap, actual), expected)

def gather_leaf_tensors(args, kwargs):
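    # Collects every Tensor in args/kwargs that has requires_grad set; these
    # are used below as the inputs to differentiate with respect to in
    # torch.autograd.grad.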
    leaf_tensors = []
    args, args_spec = tree_flatten(args)
    kwargs, kwargs_spec = tree_flatten(kwargs)
    args = args + kwargs
    for arg in args:
        if not isinstance(arg, torch.Tensor):
            continue
        if arg.requires_grad:
            leaf_tensors.append(arg)
    return leaf_tensors


def compute_expected_grads(op, args, kwargs, output_process_fn_grad=None, gradcheck_wrapper=None):
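    # Runs `op` (optionally through `gradcheck_wrapper` and
    # `output_process_fn_grad`), keeps only the differentiable Tensor outputs,
    # and backpropagates all-ones grad_outputs to the leaf tensors gathered
    # from args/kwargs. This serves as the reference gradient for the
    # subclass-wrapped runs in check_backward_formula.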
    if gradcheck_wrapper is None:
        results = op(*args, **kwargs)
    else:
        results = gradcheck_wrapper(op, *args, **kwargs)

    if output_process_fn_grad is not None:
        results = output_process_fn_grad(results)

    flat_results = pytree.tree_leaves(results)
    flat_results = [r for r in flat_results if isinstance(r, torch.Tensor)]
    flat_diff_results = [r for r in flat_results if r.requires_grad]
    assert len(flat_diff_results) > 0

    grads = [torch.ones(r.shape, device=r.device, dtype=r.dtype) for r in flat_diff_results]
    leaf_tensors = gather_leaf_tensors(args, kwargs)
    assert len(leaf_tensors) > 0
    return torch.autograd.grad(flat_diff_results, leaf_tensors,
                               grads, allow_unused=True, retain_graph=True)


# Checks if the backward formula is composite compliant by testing
# all possible permutations of {inputs, grad_outputs} being
# CompositeCompliantTensor or regular Tensors.
#
# NB: it is important that op is accepted as a Callable and not an OpInfo;
# this means we can apply check_backward_formula to things that aren't OpInfos
# while debugging.
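#
# Hedged usage sketch (illustrative; assumes `a` and `b` are Tensors created
# with requires_grad=True):
#   check_backward_formula(torch.mul, (a, b), {},
#                          assert_equal_fn=torch.testing.assert_close)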
def check_backward_formula(op: Callable, args, kwargs,
                           output_process_fn_grad=None,
                           gradcheck_wrapper=None, assert_equal_fn=None):
    CCT, cct_mode = generate_cct_and_mode()

    expected = compute_expected_grads(op, args, kwargs, output_process_fn_grad, gradcheck_wrapper)

    for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
        new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice
        leaf_tensors = gather_leaf_tensors(new_args, new_kwargs)
        assert len(leaf_tensors) > 0

        try:
            if gradcheck_wrapper is None:
                results = op(*new_args, **new_kwargs)
            else:
                results = gradcheck_wrapper(op, *new_args, **new_kwargs)
            if output_process_fn_grad is not None:
                results = output_process_fn_grad(results)
        # see NOTE: [What errors are Composite Compliance trying to catch?]
        except RuntimeError as err:
            raise_composite_compliance_error(
                err,
                f"- wrapped_args: {which_args_are_wrapped}\n"
                f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n"
            )

        flat_results = pytree.tree_leaves(results)
        flat_results = [r for r in flat_results if isinstance(r, torch.Tensor)]
        flat_diff_results = [r for r in flat_results if r.requires_grad]
        assert len(flat_diff_results) > 0

        # NB: ones, not ones_like, so we get a regular Tensor here
        grads = [torch.ones(r.shape, device=r.device, dtype=r.dtype)
                 for r in flat_diff_results]
        for flat_new_grads, which_grad_is_batched in generate_subclass_choices(grads, CCT, cct_mode):
            try:
                actual = torch.autograd.grad(flat_diff_results, leaf_tensors, flat_new_grads,
                                             allow_unused=True, retain_graph=True)
            # see NOTE: [What errors are Composite Compliance trying to catch?]
            except RuntimeError as err:
                raise_composite_compliance_error(
                    err,
                    f"- wrapped_args: {which_args_are_wrapped}\n"
                    f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n"
                    f"- wrapped_grads: {which_grad_is_batched}\n"
                )

            def unwrap(e):
                return e.elem if isinstance(e, CCT) else e

            assert_equal_fn(tuple(map(unwrap, actual)), expected, equal_nan=True)

# Checks if the forward AD formula is composite compliant by testing
# all possible permutations of {primals, tangents} being
# CompositeCompliantTensor or regular Tensors.
#
# NB: it is important that op is accepted as a Callable and not an OpInfo;
# this means we can apply check_forward_ad_formula to things that aren't OpInfos
# while debugging.
def check_forward_ad_formula(op: Callable, args, kwargs, gradcheck_wrapper=None, assert_equal_fn=None):
    CCT, cct_mode = generate_cct_and_mode(autograd_view_consistency=False)

    def maybe_tangent(t):
        assert type(t) is not CCT
        # Generate a `tangent` tensor if the given object is a Tensor
        # (or a tensorlist) with requires_grad set.
        if isinstance(t, torch.Tensor) and t.requires_grad:
            return torch.randn_like(t)
        elif is_tensorlist(t):
            return [torch.randn_like(e) if e.requires_grad else None for e in t]
        return None

    tangent_args = tuple(maybe_tangent(arg) for arg in args)
    flat_kwargs, spec = tree_flatten(kwargs)
    flat_tangent_kwargs = tuple(maybe_tangent(arg) for arg in flat_kwargs)
    tangent_kwargs = tree_unflatten(flat_tangent_kwargs, spec)

    with fwAD.dual_level():
        def maybe_make_dual(dual):
            # Returns dual tensor if primal is a tensor/tensor subclass
            # with requires_grad set.
            primal, tangent = dual
            if isinstance(primal, torch.Tensor) and primal.requires_grad:
                return fwAD.make_dual(primal.detach(), tangent)
            elif is_tensorlist(primal):
                return tuple(fwAD.make_dual(pri.detach(), tang) if tang is not None else pri
                             for pri, tang in zip(primal, tangent))
            return primal

        def compute_expected_grad(args, tangent_args, kwargs, tangent_kwargs):
            op_args = tuple(map(maybe_make_dual, zip(args, tangent_args)))
            op_kwargs = {k: maybe_make_dual((v, tangent_kwargs[k])) for k, v in kwargs.items()}

            if gradcheck_wrapper is None:
                return op(*op_args, **op_kwargs)
            return gradcheck_wrapper(op, *op_args, **op_kwargs)

        expected = compute_expected_grad(args, tangent_args, kwargs, tangent_kwargs)
        expected = tree_map(fwAD.unpack_dual, expected)
        expected_primals = tree_map(lambda x: x.primal, expected)
        expected_tangents = tree_map(lambda x: x.tangent, expected)

        # Permutations of args and kwargs in CCT.
        for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
            new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice

            # Permutations of tangent args and tangent kwargs in CCT.
            for tang_choice in generate_subclass_choices_args_kwargs(tangent_args, tangent_kwargs, CCT, cct_mode):
                new_tang_args, new_tang_kwargs, \
                    which_tang_args_are_wrapped, which_tang_kwargs_are_wrapped = tang_choice

                op_args = tuple(map(maybe_make_dual, zip(new_args, new_tang_args)))
                op_kwargs = {k: maybe_make_dual((v, new_tang_kwargs[k])) for k, v in new_kwargs.items()}

                try:
                    if gradcheck_wrapper is None:
                        actual = op(*op_args, **op_kwargs)
                    else:
                        actual = gradcheck_wrapper(op, *op_args, **op_kwargs)
                # see NOTE: [What errors are Composite Compliance trying to catch?]
                except RuntimeError as err:
                    raise_composite_compliance_error(
                        err,
                        f"- wrapped_args: {which_args_are_wrapped}\n"
                        f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n"
                        f"- wrapped_tangent_args: {which_tang_args_are_wrapped}\n"
                        f"- wrapped_tangent_kwargs: {which_tang_kwargs_are_wrapped}\n"
                    )

                def unwrap(e):
                    return e.elem if isinstance(e, CCT) else e

                actual = tree_map(fwAD.unpack_dual, actual)
                actual_primals = tree_map(lambda x: unwrap(x.primal), actual)
                actual_tangents = tree_map(lambda x: unwrap(x.tangent), actual)
                assert_equal_fn(actual_primals, expected_primals, equal_nan=True)
                assert_equal_fn(actual_tangents, expected_tangents, equal_nan=True)