xref: /aosp_15_r20/external/pytorch/functorch/examples/compilation/linear_train.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Copyright (c) Facebook, Inc. and its affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import time
8
9import torch
10import torch.nn as nn
11from functorch import make_functional
12from functorch.compile import nnc_jit
13
14
# NOTE(review): private TorchScript hook (torch._C); presumably allows the JIT
# fuser to run on CPU so the scripted / NNC-compiled paths below can fuse
# kernels. Confirm it still exists in the PyTorch version in use.
torch._C._jit_override_can_fuse_on_cpu(True)
16
17
def bench(f, iters=100, warmup=10):
    """Time ``iters`` calls of ``f`` after ``warmup`` untimed calls.

    Args:
        f: zero-argument callable to benchmark.
        iters: number of timed calls.
        warmup: number of untimed calls made first.

    Returns:
        The elapsed wall time in seconds for the timed calls. The value is
        also printed, as before.
    """
    for _ in range(warmup):
        f()
    # perf_counter is monotonic and high-resolution. time.time() can jump
    # when the system clock is adjusted, so it is unsuitable for timing.
    begin = time.perf_counter()
    for _ in range(iters):
        f()
    elapsed = time.perf_counter() - begin
    print(elapsed)
    return elapsed
25
26
class Foo(nn.Module):
    """Chain of bias-free square ``nn.Linear`` layers with a sum-of-squares head.

    Args:
        num_layers: how many linear layers are stacked in ``self.mod``.
        features: input and output width of every layer.
    """

    def __init__(self, num_layers=3, features=100):
        super().__init__()
        # Layers are created one at a time, in order, while the generator is
        # unpacked into nn.Sequential.
        layers = (
            nn.Linear(features, features, bias=False) for _ in range(num_layers)
        )
        self.mod = nn.Sequential(*layers)

    def forward(self, x):
        """Return the 0-d tensor sum of the squared outputs of ``self.mod``."""
        activations = self.mod(x)
        return (activations ** 2).sum()
37
38
# Problem size for the benchmark.
batch_size, features, num_layers = 16, 64, 8
# Fixed batch; ``train`` feeds it to the untimed warmup call.
inp = torch.randn((batch_size, features))

mod = Foo(num_layers=num_layers, features=features)

# TorchScript version of ``mod``; ``jit_step`` trains it through ``optim``.
jit_mod = torch.jit.script(mod)

# Functional view of ``mod``: ``func_model(weights, x)`` runs the network with
# explicitly supplied parameter tensors (see ``functional_step``).
func_model, weights = make_functional(mod)
lr = 1.0  # learning rate shared by functional_step and the SGD optimizer
50
51
def functional_step(x, weights):
    """Take one plain-SGD step on ``func_model``; return ``(loss, new_weights)``.

    Each incoming tensor is detached and re-marked as requiring grad, so
    ``backward()`` fills ``.grad`` on these fresh leaves rather than on the
    caller's tensors. The update ``w - lr * grad`` is computed out of place.
    """
    leaves = [w.detach().requires_grad_() for w in weights]
    loss = func_model(leaves, x)
    loss.backward()
    return loss, [leaf - lr * leaf.grad for leaf in leaves]
58
59
# Plain SGD over the scripted module's parameters: no momentum, dampening or
# weight decay. It uses the same ``lr`` as ``functional_step``, so both
# training paths apply the same update rule.
optim = torch.optim.SGD(
    params=jit_mod.parameters(),
    lr=lr,
    momentum=0,
    dampening=0,
    weight_decay=0,
)
63
64
def jit_step(x, weights):
    """Take one optimizer step on ``jit_mod``; return ``(loss, None)``.

    ``weights`` is accepted only to match ``functional_step``'s signature and
    is ignored. The parameters live inside ``jit_mod`` and are updated in
    place by ``optim.step()``.
    """
    optim.zero_grad()
    batch_loss = jit_mod(x)
    batch_loss.backward()
    optim.step()
    return batch_loss, None
71
72
def train(train_step, weights, iters=1000, log_every=200):
    """Run and time ``iters`` training steps with ``train_step``, printing progress.

    Args:
        train_step: callable ``(x, weights) -> (loss, new_weights)``, such as
            ``functional_step``, its NNC-compiled form, or ``jit_step``.
        weights: initial weights passed to ``train_step``. Use ``None`` for
            ``jit_step``, which keeps its parameters inside ``jit_mod``.
        iters: number of timed steps (default 1000, as before).
        log_every: print the loss every ``log_every`` steps (default 200).
            Values <= 0 disable the periodic loss printout.
    """
    torch.manual_seed(16)
    # Untimed warmup step, presumably to keep one-off costs such as
    # compilation out of the timing. Keep its updated weights: ``jit_step``
    # already applies its warmup update in place via ``optim.step()``.
    # Discarding the functional variants' update would leave them one step
    # behind the JIT run, and the printed losses would not be comparable.
    _, weights = train_step(inp, weights)
    # perf_counter is monotonic and high-resolution, unlike time.time().
    begin = time.perf_counter()
    for itr in range(iters):
        loss, weights = train_step(torch.randn(batch_size, features), weights)
        if log_every > 0 and itr % log_every == 0:
            print(f"Loss at {itr}: {loss}")
    print("Time taken: ", time.perf_counter() - begin)
    print()
83
84
grad_pt = functional_step
grad_nnc = nnc_jit(functional_step)

# Same schedule for each backend, in order: eager functional (PT), the
# NNC-compiled functional step, then the TorchScript module with its optimizer.
for _label, _step, _init in (
    ("PT", grad_pt, weights),
    ("NNC", grad_nnc, weights),
    ("JIT", jit_step, None),
):
    print(f"Starting {_label} training")
    train(_step, _init)
96