# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch


class LinearModel(torch.nn.Module):
    """Element-wise affine map ``a * x + b`` with fixed constant tensors.

    ``a`` is a 2x2 float32 tensor filled with 3 and ``b`` is a 2x2 float32
    tensor filled with 2, so ``forward`` returns ``3 * x + 2`` for any input
    that broadcasts against a 2x2 tensor. The module defines no
    ``torch.nn.Parameter``s, so it has nothing to train.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): these are plain tensor attributes, not registered
        # buffers or parameters, so they do not appear in state_dict() and
        # are not moved by .to()/.cuda(); confirm that is intended.
        self.a = torch.full((2, 2), 3.0, dtype=torch.float)
        self.b = torch.full((2, 2), 2.0, dtype=torch.float)

    def forward(self, x: torch.Tensor):
        """Return ``a * x + b`` as ``torch.add(torch.mul(a, x), b)``.

        Args:
            x: Tensor broadcastable with shape (2, 2).
        """
        # Explicit torch.mul / torch.add calls are kept (rather than the
        # * and + operators) because FX tracing records those differently;
        # presumably graph-matching callers rely on these ops — verify
        # before switching to operators.
        return torch.add(torch.mul(self.a, x), self.b)