import torch

# Repetition counts. Nothing in this file reads them; presumably they are
# imported by benchmark scripts (NOTE(review): confirm against callers).
NUM_REPEATS = 1000
NUM_REPEAT_OF_REPEATS = 1000


class SubTensor(torch.Tensor):
    """Plain ``torch.Tensor`` subclass that adds no behavior of its own."""


class WithTorchFunction:
    """Tensor-like wrapper (not a ``torch.Tensor`` subclass).

    Holds a tensor in ``_tensor`` and intercepts torch functions through the
    ``__torch_function__`` protocol.
    """

    def __init__(self, data, requires_grad=False):
        """Wrap ``data``.

        Args:
            data: an existing ``torch.Tensor`` (stored as-is, not copied), or
                anything ``torch.tensor`` accepts.
            requires_grad: forwarded to ``torch.tensor``; ignored when
                ``data`` is already a tensor.
        """
        if not isinstance(data, torch.Tensor):
            data = torch.tensor(data, requires_grad=requires_grad)
        self._tensor = data

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Handle any torch function invoked on a ``WithTorchFunction``.

        ``func``, ``types`` and ``kwargs`` are not consulted: every call is
        treated as adding the wrapped tensors of the first two positional
        arguments, and the sum is returned re-wrapped.
        """
        kwargs = {} if kwargs is None else kwargs  # normalized, never used
        lhs = args[0]._tensor
        rhs = args[1]._tensor
        return WithTorchFunction(lhs + rhs)


class SubWithTorchFunction(torch.Tensor):
    """``torch.Tensor`` subclass whose ``__torch_function__`` override only
    delegates back to ``torch.Tensor``'s implementation."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Forward to the parent ``__torch_function__``, defaulting
        ``kwargs`` to an empty dict when it is ``None``."""
        return super().__torch_function__(
            func, types, args, {} if kwargs is None else kwargs
        )