import torch.nn as nn


class EmbeddingNetwork1(nn.Module):
    """Embedding lookup followed by a linear projection to a single output.

    Integer indices are mapped through an ``nn.Embedding`` table and the
    resulting vectors are projected to one value by ``nn.Linear``. No output
    activation is applied.

    Args:
        dim: Size of each embedding vector (input width of the linear layer).
        in_space: Number of rows in the embedding table; indices must lie in
            ``[0, in_space)``. Defaults to 10, the size that used to be
            hard-coded here, so existing callers see no change.
    """

    def __init__(self, dim=5, in_space=10):
        super().__init__()
        self.emb = nn.Embedding(in_space, dim)
        self.lin1 = nn.Linear(dim, 1)
        # The same layer objects are reachable both as attributes and inside
        # ``seq``. Parameters are shared (``parameters()`` yields each once),
        # but ``state_dict`` contains both ``emb.*``/``lin1.*`` and ``seq.*``
        # keys. Keep this layout so previously saved checkpoints still load.
        self.seq = nn.Sequential(
            self.emb,
            self.lin1,
        )

    def forward(self, input):
        """Return ``lin1(emb(input))``.

        Args:
            input: Integer (Long/Int) index tensor of any shape. The name
                shadows the builtin but is kept for keyword callers.

        Returns:
            Float tensor of shape ``(*input.shape, 1)``.
        """
        return self.seq(input)


class EmbeddingNetwork2(nn.Module):
    """Embedding lookup, linear projection to one output, then a sigmoid.

    Args:
        in_space: Number of rows in the embedding table; indices must lie in
            ``[0, in_space)``.
        dim: Size of each embedding vector.
    """

    def __init__(self, in_space=10, dim=3):
        super().__init__()
        self.embedding = nn.Embedding(in_space, dim)
        # ``self.embedding`` is also ``seq[0]``: parameters are shared, but
        # ``state_dict`` lists it under both ``embedding.*`` and ``seq.0.*``.
        self.seq = nn.Sequential(self.embedding, nn.Linear(dim, 1), nn.Sigmoid())

    def forward(self, indices):
        """Return ``sigmoid(linear(embedding(indices)))``.

        Args:
            indices: Integer (Long/Int) index tensor of any shape.

        Returns:
            Float tensor of shape ``(*indices.shape, 1)`` with values in
            ``[0, 1]``.
        """
        return self.seq(indices)