xref: /aosp_15_r20/external/pytorch/test/package/package_c/test_module.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: package/deploy"]
2
3import torch
4
5
# torchvision is an optional dependency: when it cannot be imported the module
# still loads, and TorchVisionTest is simply left undefined.
try:
    from torchvision.models import resnet18
except ImportError:
    pass
else:

    class TorchVisionTest(torch.nn.Module):
        """Module that owns a torchvision ResNet-18 but never calls it in forward.

        NOTE(review): ``tvmod`` appears to exist only so that the module carries a
        torchvision dependency (e.g. for packaging tests). Confirm against callers.
        """

        def __init__(self) -> None:
            super().__init__()
            # Registered as a submodule attribute. It is not used by forward().
            self.tvmod = resnet18()

        def forward(self, x):
            """Return ``relu(a_non_torch_leaf(x, x) + 3.0)``.

            ``a_non_torch_leaf`` is looked up as a module-level global at call time.
            """
            combined = a_non_torch_leaf(x, x)
            shifted = combined + 3.0
            return torch.relu(shifted)
20
21
def a_non_torch_leaf(a, b):
    """Return ``a + b`` using plain Python addition, with no torch calls.

    NOTE(review): the name suggests this serves as a non-torch "leaf" function
    for packaging/tracing tests. Confirm against callers.
    """
    total = a + b
    return total
24