# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch.nn as nn
from torch.nn import functional as F


# Basic net for XOR.
class Net(nn.Module):
    """Two-layer MLP: Linear(2 -> 10), sigmoid, Linear(10 -> 2).

    ``forward`` returns the raw output of the second linear layer; no
    softmax is applied here (``TrainingNet`` feeds these values straight
    into ``nn.CrossEntropyLoss``).
    """

    def __init__(self):
        super().__init__()
        # Attribute names (and their registration order) determine the
        # state_dict keys, so they are kept as-is.
        self.linear = nn.Linear(2, 10)
        self.linear2 = nn.Linear(10, 2)

    def forward(self, x):
        """Map ``x`` (last dimension of size 2) to 2 output values per row.

        Assumes ``x`` is a float tensor shaped ``(batch, 2)`` — TODO confirm
        against callers; ``nn.Linear`` only requires the last dim to be 2.
        """
        return self.linear2(F.sigmoid(self.linear(x)))


# On-device training requires the loss to be computed inside the model and
# returned as the first output. We wrap the original model here and add the
# loss calculation; this wrapper is the model we export.
class TrainingNet(nn.Module):
    """Wrap ``net`` so that ``forward`` returns ``(loss, predictions)``."""

    def __init__(self, net):
        super().__init__()
        self.net = net
        self.loss = nn.CrossEntropyLoss()

    def forward(self, input, label):
        """Run ``net`` on ``input`` and score it against ``label``.

        Args:
            input: Batch passed unchanged to the wrapped ``net``.
            label: Targets for ``nn.CrossEntropyLoss``; presumably int64
                class indices shaped ``(batch,)`` — confirm against the
                training driver.

        Returns:
            A ``(loss, predictions)`` tuple. ``loss`` is the cross-entropy
            loss (kept first, as on-device training expects).
            ``predictions`` is the argmax over dim 1 of the net's output.
        """
        pred = self.net(input)
        # detach(): the predictions output is only for reporting and should
        # carry no autograd history (argmax is not differentiable anyway).
        return self.loss(pred, label), pred.detach().argmax(dim=1)