xref: /aosp_15_r20/external/pytorch/torch/testing/_internal/two_tensor.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3import torch
4import torch.utils._pytree as pytree
5from torch.utils._python_dispatch import return_and_correct_aliasing
6
7
8# A simple tensor subclass that holds two tensors internally, and runs every op on both tensors.
9class TwoTensor(torch.Tensor):
10    @staticmethod
11    def __new__(cls, a, b):
12        assert (
13            a.device == b.device
14            and a.layout == b.layout
15            and a.requires_grad == b.requires_grad
16            and a.dtype == b.dtype
17        )
18        # I guess it would be more accurate to represent the shape as torch.cat(a, b).shape
19        shape = a.shape
20        kwargs = {}
21        kwargs["strides"] = a.stride()
22        kwargs["storage_offset"] = a.storage_offset()
23        kwargs["device"] = a.device
24        kwargs["layout"] = a.layout
25        kwargs["requires_grad"] = a.requires_grad
26        kwargs["dtype"] = a.dtype
27        out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
28
29        assert a.shape == b.shape
30        assert a.stride() == b.stride()
31        assert a.storage_offset() == b.storage_offset()
32        return out
33
34    def __init__(self, a, b):
35        self.a = a
36        self.b = b
37
38    def __repr__(self):
39        a_repr = repr(self.a)
40        b_repr = repr(self.b)
41        return f"TwoTensor({a_repr}, {b_repr})"
42
43    def __tensor_flatten__(self):
44        return ["a", "b"], None
45
46    @staticmethod
47    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
48        assert meta is None
49        a, b = inner_tensors["a"], inner_tensors["b"]
50        return TwoTensor(a, b)
51
52    @classmethod
53    def __torch_dispatch__(cls, func, types, args, kwargs):
54        if kwargs is None:
55            kwargs = {}
56        args_a = pytree.tree_map_only(TwoTensor, lambda x: x.a, args)
57        args_b = pytree.tree_map_only(TwoTensor, lambda x: x.b, args)
58
59        kwargs_a = pytree.tree_map_only(TwoTensor, lambda x: x.a, kwargs)
60        kwargs_b = pytree.tree_map_only(TwoTensor, lambda x: x.b, kwargs)
61
62        out_a = func(*args_a, **kwargs_a)
63        out_b = func(*args_b, **kwargs_b)
64        out_a_flat, spec = pytree.tree_flatten(out_a)
65        out_b_flat = pytree.tree_leaves(out_b)
66        # for aten ops that return non-tensors, just assume that
67        # our two inner tensors return the same value
68        out_flat = [
69            TwoTensor(o_a, o_b) if isinstance(o_a, torch.Tensor) else o_a
70            for o_a, o_b in zip(out_a_flat, out_b_flat)
71        ]
72        out = pytree.tree_unflatten(out_flat, spec)
73        from torch._higher_order_ops.cond import cond_op
74
75        if func is cond_op:
76            return out
77        else:
78            return return_and_correct_aliasing(func, args, kwargs, out)
79
80
class TwoTensorMode(torch.utils._python_dispatch.TorchDispatchMode):
    """Dispatch mode that wraps freshly constructed tensors as ``TwoTensor``.

    Every op runs normally. If
    ``torch._subclasses.fake_tensor._is_tensor_constructor`` classifies the
    op as a tensor constructor, its result ``out`` is replaced by
    ``TwoTensor(out, out.clone())``. Constructors are presumably factory
    ops such as ``torch.ones``; confirm against that helper.
    """

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # ``func(**None)`` raises TypeError, so normalize the declared default
        # (mirrors the handling in TwoTensor.__torch_dispatch__).
        if kwargs is None:
            kwargs = {}
        out = func(*args, **kwargs)
        if torch._subclasses.fake_tensor._is_tensor_constructor(func):
            out = TwoTensor(out, out.clone())
        return out
87