1# mypy: ignore-errors 2 3import torch 4import torch.utils._pytree as pytree 5from torch.utils._python_dispatch import return_and_correct_aliasing 6 7 8# A simple tensor subclass that holds two tensors internally, and runs every op on both tensors. 9class TwoTensor(torch.Tensor): 10 @staticmethod 11 def __new__(cls, a, b): 12 assert ( 13 a.device == b.device 14 and a.layout == b.layout 15 and a.requires_grad == b.requires_grad 16 and a.dtype == b.dtype 17 ) 18 # I guess it would be more accurate to represent the shape as torch.cat(a, b).shape 19 shape = a.shape 20 kwargs = {} 21 kwargs["strides"] = a.stride() 22 kwargs["storage_offset"] = a.storage_offset() 23 kwargs["device"] = a.device 24 kwargs["layout"] = a.layout 25 kwargs["requires_grad"] = a.requires_grad 26 kwargs["dtype"] = a.dtype 27 out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) 28 29 assert a.shape == b.shape 30 assert a.stride() == b.stride() 31 assert a.storage_offset() == b.storage_offset() 32 return out 33 34 def __init__(self, a, b): 35 self.a = a 36 self.b = b 37 38 def __repr__(self): 39 a_repr = repr(self.a) 40 b_repr = repr(self.b) 41 return f"TwoTensor({a_repr}, {b_repr})" 42 43 def __tensor_flatten__(self): 44 return ["a", "b"], None 45 46 @staticmethod 47 def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride): 48 assert meta is None 49 a, b = inner_tensors["a"], inner_tensors["b"] 50 return TwoTensor(a, b) 51 52 @classmethod 53 def __torch_dispatch__(cls, func, types, args, kwargs): 54 if kwargs is None: 55 kwargs = {} 56 args_a = pytree.tree_map_only(TwoTensor, lambda x: x.a, args) 57 args_b = pytree.tree_map_only(TwoTensor, lambda x: x.b, args) 58 59 kwargs_a = pytree.tree_map_only(TwoTensor, lambda x: x.a, kwargs) 60 kwargs_b = pytree.tree_map_only(TwoTensor, lambda x: x.b, kwargs) 61 62 out_a = func(*args_a, **kwargs_a) 63 out_b = func(*args_b, **kwargs_b) 64 out_a_flat, spec = pytree.tree_flatten(out_a) 65 out_b_flat = pytree.tree_leaves(out_b) 66 # for aten ops that return non-tensors, just assume that 67 # our two inner tensors return the same value 68 out_flat = [ 69 TwoTensor(o_a, o_b) if isinstance(o_a, torch.Tensor) else o_a 70 for o_a, o_b in zip(out_a_flat, out_b_flat) 71 ] 72 out = pytree.tree_unflatten(out_flat, spec) 73 from torch._higher_order_ops.cond import cond_op 74 75 if func is cond_op: 76 return out 77 else: 78 return return_and_correct_aliasing(func, args, kwargs, out) 79 80 81class TwoTensorMode(torch.utils._python_dispatch.TorchDispatchMode): 82 def __torch_dispatch__(self, func, types, args=(), kwargs=None): 83 out = func(*args, **kwargs) 84 if torch._subclasses.fake_tensor._is_tensor_constructor(func): 85 out = TwoTensor(out, out.clone()) 86 return out 87