xref: /aosp_15_r20/external/executorch/extension/training/examples/XOR/model.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-unsafe
8
9import torch.nn as nn
10from torch.nn import functional as F
11
12
13# Basic Net for XOR
class Net(nn.Module):
    """Small two-layer perceptron used for the XOR example.

    Inputs whose last dimension is 2 pass through a 10-unit hidden layer
    with a sigmoid activation, then a linear layer that produces 2 output
    scores. No activation is applied to the output scores.
    """

    def __init__(self):
        super().__init__()
        # The attribute names determine the state_dict keys
        # ("linear.*", "linear2.*"), so they must not be renamed.
        self.linear = nn.Linear(in_features=2, out_features=10)  # input -> hidden
        self.linear2 = nn.Linear(in_features=10, out_features=2)  # hidden -> output

    def forward(self, x):
        """Return the 2 raw output scores for input ``x``."""
        hidden = F.sigmoid(self.linear(x))
        return self.linear2(hidden)
22
23
# On-device training requires the loss to be computed inside the model and returned as its first output.
# We wrap the original model here and add the loss calculation; this wrapper is the model we export.
class TrainingNet(nn.Module):
    """Wrap a model so that its forward pass also computes the training loss.

    The loss is cross-entropy between the wrapped model's output and the
    provided labels. It is returned first, followed by the predicted class
    indices.
    """

    def __init__(self, net):
        super().__init__()
        # Keep these attribute names: the wrapped model's parameters appear
        # in the state_dict under the "net." prefix.
        self.net = net
        self.loss = nn.CrossEntropyLoss()

    def forward(self, input, label):
        """Return ``(loss, predicted_class_indices)`` for one batch.

        NOTE(review): ``label`` is presumably a tensor of class indices (the
        form CrossEntropyLoss accepts for integer targets) -- confirm with
        callers.
        """
        logits = self.net(input)
        loss_value = self.loss(logits, label)
        # argmax over dim 1 (the class dimension CrossEntropyLoss uses);
        # detached so the prediction output carries no autograd history.
        predicted = logits.detach().argmax(dim=1)
        return loss_value, predicted
35