xref: /aosp_15_r20/external/pytorch/test/package/package_a/fake_interface.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import torch
2from torch import Tensor
3
4
@torch.jit.interface
class ModuleInterface(torch.nn.Module):
    """TorchScript module interface that declares a single method, ``one``.

    Only the signature is meaningful: ``@torch.jit.interface`` turns this
    declaration into an interface type. The method body is a placeholder.
    """

    def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
        """Take two tensors and return a tensor; implementers define the math."""
        pass
9
10
class OrigModule(torch.nn.Module):
    """Original implementation of the ``ModuleInterface`` contract.

    It matches the interface structurally by defining ``one`` with the same
    signature; it does not subclass ``ModuleInterface``. ``two`` is an extra
    method that the interface does not declare.
    """

    def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
        """Return ``(inp1 + inp2) + 1``."""
        pair_sum = inp1 + inp2
        return pair_sum + 1

    def two(self, input: Tensor) -> Tensor:
        """Return ``input`` shifted by 2."""
        shifted = input + 2
        return shifted

    def forward(self, input: Tensor) -> Tensor:
        """Return ``(input + one(input, input)) + 1``."""
        via_one = self.one(input, input)
        # Same operand order as ``input + one(...) + 1`` so results match exactly.
        return input + via_one + 1
22
23
class NewModule(torch.nn.Module):
    """Alternative implementation of the ``ModuleInterface`` contract.

    ``one`` has the same signature as the interface method but computes a
    product-based expression instead of a sum.
    """

    def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
        """Return ``inp1 * inp2 + 1``."""
        product = inp1 * inp2
        return product + 1

    def forward(self, input: Tensor) -> Tensor:
        """Return ``one(input, input + 1)``."""
        next_val = input + 1
        return self.one(input, next_val)
32
33
class UsesInterface(torch.nn.Module):
    """Module whose ``proxy_mod`` submodule is typed by ``ModuleInterface``.

    The class-level annotation declares ``proxy_mod`` with the interface type
    rather than a concrete class. The constructor installs an ``OrigModule``.
    """

    proxy_mod: ModuleInterface

    def __init__(self) -> None:
        super().__init__()
        default_impl = OrigModule()
        self.proxy_mod = default_impl

    def forward(self, input: Tensor) -> Tensor:
        """Call ``proxy_mod.one`` with ``input`` passed for both arguments."""
        # Call through the attribute directly rather than aliasing the
        # submodule to a local, keeping the body TorchScript-friendly.
        return self.proxy_mod.one(input, input)
43