# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. # pyre-unsafe import torch.nn as nn from torch.nn import functional as F # Basic Net for XOR class Net(nn.Module): def __init__(self): super().__init__() self.linear = nn.Linear(2, 10) self.linear2 = nn.Linear(10, 2) def forward(self, x): return self.linear2(F.sigmoid(self.linear(x))) # On device training requires the loss to be embedded in the model (and be the first output). # We wrap the original model here and add the loss calculation. This will be the model we export. class TrainingNet(nn.Module): def __init__(self, net): super().__init__() self.net = net self.loss = nn.CrossEntropyLoss() def forward(self, input, label): pred = self.net(input) return self.loss(pred, label), pred.detach().argmax(dim=1)