# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Compare three ways of running the same SGD training loop.

The script trains a small stack of bias-free linear layers with:

* a hand-written functional step (``functional_step``),
* the same step wrapped in ``nnc_jit``,
* a TorchScript-scripted module driven by ``torch.optim.SGD``.

For each variant it prints the loss every 200 iterations and the elapsed
time for 1000 iterations.
"""

import time

import torch
import torch.nn as nn
from functorch import make_functional
from functorch.compile import nnc_jit


# Private JIT switch; presumably required so fusion is attempted on CPU for
# the NNC run -- NOTE(review): confirm against the installed torch version.
torch._C._jit_override_can_fuse_on_cpu(True)


def bench(f, iters=100, warmup=10):
    """Call ``f`` repeatedly and print the elapsed time of the timed calls.

    Args:
        f: Zero-argument callable to time.
        iters: Number of timed calls.
        warmup: Number of untimed calls made before timing starts.
    """
    for _ in range(warmup):
        f()
    # perf_counter is monotonic, so the interval cannot be skewed by
    # system clock adjustments the way time.time() can.
    begin = time.perf_counter()
    for _ in range(iters):
        f()
    print(time.perf_counter() - begin)


class Foo(nn.Module):
    """A stack of square, bias-free linear layers with a scalar output."""

    def __init__(self, num_layers=3, features=100):
        """Build ``num_layers`` ``Linear(features, features, bias=False)`` layers.

        Args:
            num_layers: How many linear layers to chain.
            features: Input and output width of every layer.
        """
        super().__init__()
        # Layers are created in order, so parameter initialisation consumes
        # the RNG exactly as an explicit append loop would.
        self.mod = nn.Sequential(
            *(
                nn.Linear(features, features, bias=False)
                for _ in range(num_layers)
            )
        )

    def forward(self, x):
        """Return the sum of the squared outputs of the layer stack."""
        return (self.mod(x) ** 2).sum()


batch_size = 16
features = 64
num_layers = 8
inp = torch.randn((batch_size, features))

mod = Foo(num_layers, features)

jit_mod = torch.jit.script(mod)

# func_model is invoked below as func_model(weights, x).
func_model, weights = make_functional(mod)
lr = 1.0


def functional_step(x, weights):
    """Run one gradient-descent step without mutating ``weights``.

    Args:
        x: Input batch passed to ``func_model``.
        weights: Iterable of weight tensors for ``func_model``.

    Returns:
        A ``(loss, new_weights)`` tuple, where ``new_weights`` is a new list
        holding ``weight - lr * weight.grad`` for every weight.
    """
    # Detach so each step starts from fresh leaf tensors that accumulate
    # their own .grad, independent of any previous step's graph.
    weights = [weight.detach().requires_grad_() for weight in weights]
    out = func_model(weights, x)
    out.backward()
    new_weights = [weight - lr * weight.grad for weight in weights]
    return out, new_weights


optim = torch.optim.SGD(
    jit_mod.parameters(), lr=lr, momentum=0, dampening=0, weight_decay=0
)


def jit_step(x, weights):
    """Run one optimizer step on ``jit_mod`` in place.

    Args:
        x: Input batch passed to ``jit_mod``.
        weights: Ignored; accepted so the signature matches
            ``functional_step`` and both can be driven by ``train``.

    Returns:
        A ``(loss, None)`` tuple; the parameters live inside ``jit_mod``.
    """
    optim.zero_grad()
    loss = jit_mod(x)
    loss.backward()
    optim.step()
    return loss, None


def train(train_step, weights):
    """Time 1000 iterations of ``train_step`` and print progress.

    Args:
        train_step: Callable taking ``(x, weights)`` and returning
            ``(loss, weights)``.
        weights: Initial weights handed to the first timed iteration.
    """
    # Same seed for every variant so all runs see the same random inputs.
    torch.manual_seed(16)
    # Untimed first call; its return value is discarded. For jit_step this
    # call still applies one optimizer update to jit_mod.
    train_step(inp, weights)
    begin = time.perf_counter()
    for itr in range(1000):
        loss, weights = train_step(torch.randn(batch_size, features), weights)
        if itr % 200 == 0:
            print(f"Loss at {itr}: {loss}")
    print("Time taken: ", time.perf_counter() - begin)
    print()


grad_pt = functional_step
grad_nnc = nnc_jit(functional_step)

print("Starting PT training")
train(grad_pt, weights)

print("Starting NNC training")
train(grad_nnc, weights)

print("Starting JIT training")
train(jit_step, None)