# Source: /aosp_15_r20/external/pytorch/benchmarks/overrides_benchmark/common.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport torch
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Worker
# Repetition counts shared by the overrides benchmarks. Presumably the inner
# and outer timing-loop counts respectively -- confirm against the callers.
NUM_REPEATS = 1_000
NUM_REPEAT_OF_REPEATS = 1_000
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Worker
class SubTensor(torch.Tensor):
    """A ``torch.Tensor`` subclass that adds no behavior of its own.

    It inherits ``torch.Tensor.__torch_function__`` unchanged, so it serves as
    a plain-subclass baseline (presumably compared against the override
    classes in this module -- confirm with the benchmark drivers).
    """
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker
class WithTorchFunction:
    """Tensor-like wrapper (not a ``torch.Tensor`` subclass) that hooks torch functions.

    The instance holds a single ``torch.Tensor`` in ``_tensor``. Its
    ``__torch_function__`` override ignores which torch function was called
    and always returns a new ``WithTorchFunction`` wrapping the sum of the
    first two arguments' underlying tensors.
    """

    def __init__(self, data, requires_grad=False):
        """Wrap ``data`` as a tensor.

        Args:
            data: An existing ``torch.Tensor`` (stored as-is, without copying),
                or anything ``torch.tensor`` accepts.
            requires_grad: Passed to ``torch.tensor`` when ``data`` is not
                already a tensor; ignored otherwise.
        """
        if not isinstance(data, torch.Tensor):
            data = torch.tensor(data, requires_grad=requires_grad)
        self._tensor = data

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Handle any torch call by adding the first two arguments' tensors.

        ``func``, ``types`` and ``kwargs`` are accepted to satisfy the
        ``__torch_function__`` protocol but are not consulted. ``args[0]`` and
        ``args[1]`` must both expose a ``_tensor`` attribute.
        """
        if kwargs is None:
            kwargs = {}

        summed = args[0]._tensor + args[1]._tensor
        return WithTorchFunction(summed)
26*da0073e9SAndroid Build Coastguard Worker
27*da0073e9SAndroid Build Coastguard Worker
class SubWithTorchFunction(torch.Tensor):
    """``torch.Tensor`` subclass with a pass-through ``__torch_function__`` override.

    The override only replaces a ``None`` ``kwargs`` with an empty dict and
    forwards everything to ``torch.Tensor.__torch_function__``. Presumably this
    isolates the cost of a Python-level override -- confirm with the benchmark
    drivers.
    """

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        forwarded_kwargs = {} if kwargs is None else kwargs
        return super().__torch_function__(func, types, args, forwarded_kwargs)
35