1# mypy: ignore-errors 2 3import torch 4import torch.utils._pytree as pytree 5from torch.utils._python_dispatch import return_and_correct_aliasing 6 7 8# A simple tensor subclass that holds a tensor with custom metadata and custom method 9class ConstantExtraMetadataTensor(torch.Tensor): 10 @staticmethod 11 def __new__(cls, elem): 12 shape = elem.shape 13 kwargs = {} 14 kwargs["strides"] = elem.stride() 15 kwargs["storage_offset"] = elem.storage_offset() 16 kwargs["device"] = elem.device 17 kwargs["layout"] = elem.layout 18 kwargs["requires_grad"] = elem.requires_grad 19 kwargs["dtype"] = elem.dtype 20 return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) 21 22 def __init__(self, elem): 23 self.elem = elem 24 self.constant_attribute = 4 25 26 def __repr__(self): 27 inner_repr = repr(self.elem) 28 return f"CustomTensor({inner_repr})" 29 30 def __tensor_flatten__(self): 31 return ["elem"], self.constant_attribute 32 33 def add_constant(self, a): 34 self.constant_attribute += a 35 36 @staticmethod 37 def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride): 38 assert meta is not None 39 elem = inner_tensors["elem"] 40 out = ConstantExtraMetadataTensor(elem) 41 out.constant_attribute = meta 42 return out 43 44 @classmethod 45 def __torch_dispatch__(cls, func, types, args, kwargs): 46 if kwargs is None: 47 kwargs = {} 48 args_inner = pytree.tree_map_only( 49 ConstantExtraMetadataTensor, lambda x: x.elem, args 50 ) 51 52 kwargs_inner = pytree.tree_map_only( 53 ConstantExtraMetadataTensor, lambda x: x.elem, kwargs 54 ) 55 56 out_inner = func(*args_inner, **kwargs_inner) 57 out_inner_flat, spec = pytree.tree_flatten(out_inner) 58 # for aten ops that return non-tensors, just assume that 59 # our cust inner tensors return the same value 60 out_flat = [ 61 ConstantExtraMetadataTensor(o_inner) 62 if isinstance(o_inner, torch.Tensor) 63 else o_inner 64 for o_inner in out_inner_flat 65 ] 66 out = pytree.tree_unflatten(out_flat, spec) 67 return return_and_correct_aliasing(func, args, kwargs, out) 68