import torch
from torch import Tensor


@torch.jit.interface
class ModuleInterface(torch.nn.Module):
    """TorchScript module interface: conforming modules provide ``one``."""

    def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
        # Only the signature is declared here; the body is a placeholder.
        pass


class OrigModule(torch.nn.Module):
    """A module that implements ModuleInterface."""

    def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
        """Return ``inp1 + inp2 + 1``."""
        pair_sum = inp1 + inp2
        return pair_sum + 1

    def two(self, input: Tensor) -> Tensor:
        """Return ``input + 2`` (an extra method, not part of ModuleInterface)."""
        shifted = input + 2
        return shifted

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input + self.one(input, input) + 1``."""
        via_one = self.one(input, input)
        return input + via_one + 1


class NewModule(torch.nn.Module):
    """A *different* module that implements ModuleInterface."""

    def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
        """Return ``inp1 * inp2 + 1``."""
        product = inp1 * inp2
        return product + 1

    def forward(self, input: Tensor) -> Tensor:
        """Return ``self.one(input, input + 1)``."""
        shifted = input + 1
        return self.one(input, shifted)


class UsesInterface(torch.nn.Module):
    """Holds a ModuleInterface-typed submodule and delegates to its ``one``."""

    # This class-level annotation is what makes TorchScript type ``proxy_mod``
    # as ModuleInterface rather than as the concrete OrigModule, so any
    # conforming module (e.g. NewModule) may be assigned to it.
    proxy_mod: ModuleInterface

    def __init__(self) -> None:
        super().__init__()
        self.proxy_mod = OrigModule()

    def forward(self, input: Tensor) -> Tensor:
        """Return ``self.proxy_mod.one(input, input)``."""
        return self.proxy_mod.one(input, input)