xref: /aosp_15_r20/external/pytorch/test/onnx/model_defs/emb_seq.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import torch.nn as nn
2
3
class EmbeddingNetwork1(nn.Module):
    """Embedding lookup followed by a linear projection to a single output.

    Pipeline: ``nn.Embedding(in_space, dim)`` -> ``nn.Linear(dim, 1)``.

    Args:
        dim: Embedding width, which is also the input width of the linear layer.
        in_space: Number of embedding rows (vocabulary size). Input indices must
            lie in ``[0, in_space)``. Defaults to 10, the previously hard-coded
            value, so existing callers are unaffected.
    """

    def __init__(self, dim=5, in_space=10):
        super().__init__()
        self.emb = nn.Embedding(in_space, dim)
        self.lin1 = nn.Linear(dim, 1)
        # The same module objects are registered both as direct attributes and
        # as children of ``seq``, so their parameters are shared, not copied.
        # NOTE(review): presumably intentional, to exercise export of modules
        # reachable through several paths — confirm before deduplicating.
        self.seq = nn.Sequential(
            self.emb,
            self.lin1,
        )

    def forward(self, input):
        """Map integer index tensor ``input`` to outputs whose last dim is 1."""
        return self.seq(input)
16
17
class EmbeddingNetwork2(nn.Module):
    """Embedding lookup, linear projection to one unit, then a sigmoid.

    Pipeline: ``nn.Embedding(in_space, dim)`` -> ``nn.Linear(dim, 1)`` ->
    ``nn.Sigmoid()``, so every output element lies in (0, 1).

    Args:
        in_space: Number of embedding rows; indices must lie in ``[0, in_space)``.
        dim: Embedding width, which is also the input width of the linear layer.
    """

    def __init__(self, in_space=10, dim=3):
        super().__init__()
        # The embedding is also kept as a direct attribute; the Sequential below
        # reuses that same object rather than holding a copy.
        self.embedding = nn.Embedding(in_space, dim)
        head_layers = [nn.Linear(dim, 1), nn.Sigmoid()]
        self.seq = nn.Sequential(self.embedding, *head_layers)

    def forward(self, indices):
        """Run integer index tensor ``indices`` through the full pipeline."""
        pipeline = self.seq
        return pipeline(indices)
26