import torch NUM_REPEATS = 1000 NUM_REPEAT_OF_REPEATS = 1000 class SubTensor(torch.Tensor): pass class WithTorchFunction: def __init__(self, data, requires_grad=False): if isinstance(data, torch.Tensor): self._tensor = data return self._tensor = torch.tensor(data, requires_grad=requires_grad) @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} return WithTorchFunction(args[0]._tensor + args[1]._tensor) class SubWithTorchFunction(torch.Tensor): @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} return super().__torch_function__(func, types, args, kwargs)