xref: /aosp_15_r20/external/pytorch/torch/testing/_internal/custom_tensor.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3import torch
4import torch.utils._pytree as pytree
5from torch.utils._python_dispatch import return_and_correct_aliasing
6
7
# A simple wrapper tensor subclass that holds an inner tensor together with custom metadata and a custom method
9class ConstantExtraMetadataTensor(torch.Tensor):
10    @staticmethod
11    def __new__(cls, elem):
12        shape = elem.shape
13        kwargs = {}
14        kwargs["strides"] = elem.stride()
15        kwargs["storage_offset"] = elem.storage_offset()
16        kwargs["device"] = elem.device
17        kwargs["layout"] = elem.layout
18        kwargs["requires_grad"] = elem.requires_grad
19        kwargs["dtype"] = elem.dtype
20        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
21
22    def __init__(self, elem):
23        self.elem = elem
24        self.constant_attribute = 4
25
26    def __repr__(self):
27        inner_repr = repr(self.elem)
28        return f"CustomTensor({inner_repr})"
29
30    def __tensor_flatten__(self):
31        return ["elem"], self.constant_attribute
32
33    def add_constant(self, a):
34        self.constant_attribute += a
35
36    @staticmethod
37    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
38        assert meta is not None
39        elem = inner_tensors["elem"]
40        out = ConstantExtraMetadataTensor(elem)
41        out.constant_attribute = meta
42        return out
43
44    @classmethod
45    def __torch_dispatch__(cls, func, types, args, kwargs):
46        if kwargs is None:
47            kwargs = {}
48        args_inner = pytree.tree_map_only(
49            ConstantExtraMetadataTensor, lambda x: x.elem, args
50        )
51
52        kwargs_inner = pytree.tree_map_only(
53            ConstantExtraMetadataTensor, lambda x: x.elem, kwargs
54        )
55
56        out_inner = func(*args_inner, **kwargs_inner)
57        out_inner_flat, spec = pytree.tree_flatten(out_inner)
58        # for aten ops that return non-tensors, just assume that
59        # our cust inner tensors return the same value
60        out_flat = [
61            ConstantExtraMetadataTensor(o_inner)
62            if isinstance(o_inner, torch.Tensor)
63            else o_inner
64            for o_inner in out_inner_flat
65        ]
66        out = pytree.tree_unflatten(out_flat, spec)
67        return return_and_correct_aliasing(func, args, kwargs, out)
68