import time

import torchvision.models as models
from opacus import PrivacyEngine
from opacus.utils.module_modification import convert_batchnorm_modules

import torch
import torch.nn as nn
from functorch import grad, make_functional, vmap


device = "cuda"
batch_size = 128
torch.manual_seed(0)

model_functorch = convert_batchnorm_modules(models.resnet18(num_classes=10))
model_functorch = model_functorch.to(device)
criterion = nn.CrossEntropyLoss()

images = torch.randn(batch_size, 3, 32, 32, device=device)
targets = torch.randint(0, 10, (batch_size,), device=device)
func_model, weights = make_functional(model_functorch)


def compute_loss(weights, image, target):
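    """Loss of a single (image, target) example.

    Adds a batch dimension of size 1 so the functional model and
    CrossEntropyLoss see batched inputs, then evaluates the loss with the
    given weights (the flat parameter tuple from make_functional).
    """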
    images = image.unsqueeze(0)
    targets = target.unsqueeze(0)
    output = func_model(weights, images)
    loss = criterion(output, targets)
    return loss


def functorch_per_sample_grad():
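    """Per-sample gradients via functorch.

    grad(compute_loss) differentiates the single-example loss with respect to
    the weights (argument 0); vmap maps that gradient function over the batch
    dimension of images and targets while broadcasting the weights
    (in_dims=(None, 0, 0)). Returns the per-sample grads and elapsed seconds.
    """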
    compute_grad = grad(compute_loss)
    compute_per_sample_grad = vmap(compute_grad, (None, 0, 0))

    start = time.time()
    result = compute_per_sample_grad(weights, images, targets)
    torch.cuda.synchronize()
    end = time.time()

    return result, end - start  # end - start in seconds


torch.manual_seed(0)
model_opacus = convert_batchnorm_modules(models.resnet18(num_classes=10))
model_opacus = model_opacus.to(device)
criterion = nn.CrossEntropyLoss()
for p_f, p_o in zip(model_functorch.parameters(), model_opacus.parameters()):
    assert torch.allclose(p_f, p_o)  # Sanity check

privacy_engine = PrivacyEngine(
    model_opacus,
    sample_rate=0.01,
    alphas=[10, 100],
    noise_multiplier=1,
    max_grad_norm=10000.0,
)
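# With this (pre-1.0) Opacus API, the PrivacyEngine hooks model_opacus so that
# loss.backward() also records per-sample gradients in each parameter's
# .grad_sample attribute. The DP settings (noise_multiplier, max_grad_norm,
# alphas) are not exercised here; only the grad_sample machinery is timed.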


def opacus_per_sample_grad():
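    """Per-sample gradients via Opacus hooks.

    Times a forward/backward pass, collects each parameter's .grad_sample,
    then clears grad_sample and .grad so repeated runs start from a clean
    state. Returns the per-sample grads and elapsed seconds.
    """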
    start = time.time()
    output = model_opacus(images)
    loss = criterion(output, targets)
    loss.backward()
    torch.cuda.synchronize()
    end = time.time()
    expected = [p.grad_sample for p in model_opacus.parameters()]
    for p in model_opacus.parameters():
        delattr(p, "grad_sample")
        p.grad = None
    return expected, end - start


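# Run each implementation several times (printing the per-run time) to warm up
# CUDA/cuDNN and any caches, then keep the result of one more run of each for
# the comparison below.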
for _ in range(5):
    _, seconds = functorch_per_sample_grad()
    print(seconds)

result, seconds = functorch_per_sample_grad()
print(seconds)

for _ in range(5):
    _, seconds = opacus_per_sample_grad()
    print(seconds)

expected, seconds = opacus_per_sample_grad()
print(seconds)

result = [r.detach() for r in result]
print(len(result))

# TODO: The following shows that the per-sample-grads computed are different.
# This concerns me a little; we should compare to a source of truth.
# for i, (r, e) in enumerate(list(zip(result, expected))[::-1]):
#     if torch.allclose(r, e, rtol=1e-5):
#         continue
#     print(-(i+1), ((r - e)/(e + 0.000001)).abs().max())
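

# A possible "source of truth" for the TODO above: a minimal, untimed sketch
# (not called by the benchmark) that computes per-sample gradients with a
# plain Python loop, reusing compute_loss and the functional weights. The
# naive_per_sample_grad name is illustrative, not part of the original script.
def naive_per_sample_grad():
    per_sample_grads = []
    for i in range(batch_size):
        # One autograd.grad call per example, differentiating w.r.t. weights.
        loss = compute_loss(weights, images[i], targets[i])
        per_sample_grads.append(torch.autograd.grad(loss, weights))
    # Stack per parameter so each entry has shape (batch_size, *param.shape),
    # matching the layout produced by vmap above.
    return [torch.stack(shards) for shards in zip(*per_sample_grads)]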
98