xref: /aosp_15_r20/external/pytorch/test/linear.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import torch
2
3
class LinearMod(torch.nn.Linear):
    """``torch.nn.Linear`` whose forward calls ``torch._C._nn.linear`` directly.

    Construction (``in_features``, ``out_features``, ``bias``, ...) is inherited
    unchanged from ``torch.nn.Linear``. The redundant pass-through
    ``__init__`` was dropped, so the parent's real signature stays visible.
    """

    def forward(self, input):
        """Apply the affine map ``input @ weight.T + bias`` via the C binding.

        Args:
            input: Tensor whose last dimension equals ``in_features``.

        Returns:
            Tensor whose last dimension equals ``out_features``. ``self.bias``
            is ``None`` when the module was built with ``bias=False``; it is
            passed through unchanged in that case.
        """
        # NOTE(review): the private binding is presumably used on purpose, so
        # the trace exercises this exact entry point. Keep it unless the test
        # intent is confirmed otherwise.
        return torch._C._nn.linear(input, self.weight, self.bias)
10
11
# Build a 20 -> 20 LinearMod, trace it with a random (20, 20) example input,
# and print the recorded TorchScript graph. The module is constructed before
# the example tensor is drawn, which matches the original evaluation order.
_example_module = LinearMod(20, 20)
_example_input = torch.rand([20, 20])
_traced_module = torch.jit.trace(_example_module, _example_input)
print(_traced_module.graph)
13