import time

import torch
import torch.nn as nn
import torchvision.models as models
from functorch import grad, make_functional, vmap
from opacus import PrivacyEngine
from opacus.utils.module_modification import convert_batchnorm_modules


device = "cuda"
batch_size = 128
torch.manual_seed(0)

# BatchNorm mixes statistics across the batch, which breaks per-sample
# gradients (and DP training), so swap every BatchNorm for GroupNorm.
model_functorch = convert_batchnorm_modules(models.resnet18(num_classes=10))
model_functorch = model_functorch.to(device)
criterion = nn.CrossEntropyLoss()

images = torch.randn(batch_size, 3, 32, 32, device=device)
targets = torch.randint(0, 10, (batch_size,), device=device)
func_model, weights = make_functional(model_functorch)


def compute_loss(weights, image, target):
    # Loss for a single sample: unsqueeze to add the batch dimension
    # that the model and criterion expect.
    image_batch = image.unsqueeze(0)
    target_batch = target.unsqueeze(0)
    output = func_model(weights, image_batch)
    return criterion(output, target_batch)


def functorch_per_sample_grad():
    compute_grad = grad(compute_loss)
    # in_dims=(None, 0, 0): share the weights, map over the batch
    # dimension of images and targets.
    compute_per_sample_grad = vmap(compute_grad, (None, 0, 0))

    torch.cuda.synchronize()  # drain pending kernels before starting the clock
    start = time.time()
    result = compute_per_sample_grad(weights, images, targets)
    torch.cuda.synchronize()
    end = time.time()

    return result, end - start  # elapsed time in seconds


torch.manual_seed(0)
model_opacus = convert_batchnorm_modules(models.resnet18(num_classes=10))
model_opacus = model_opacus.to(device)
for p_f, p_o in zip(model_functorch.parameters(), model_opacus.parameters()):
    assert torch.allclose(p_f, p_o)  # sanity check: both models start identical

# Constructing a PrivacyEngine (pre-1.0 Opacus API) adds hooks so that
# backward populates p.grad_sample. The DP hyperparameters don't matter for
# this benchmark, which only times the gradient computation.
privacy_engine = PrivacyEngine(
    model_opacus,
    sample_rate=0.01,
    alphas=[10, 100],
    noise_multiplier=1,
    max_grad_norm=10000.0,
)


def opacus_per_sample_grad():
    torch.cuda.synchronize()  # drain pending kernels before starting the clock
    start = time.time()
    output = model_opacus(images)
    loss = criterion(output, targets)
    loss.backward()
    torch.cuda.synchronize()
    end = time.time()
    expected = [p.grad_sample for p in model_opacus.parameters()]
    # Clear grad_sample and grad so the next call starts from a clean slate;
    # Opacus would otherwise accumulate per-sample grads across calls.
    for p in model_opacus.parameters():
        delattr(p, "grad_sample")
        p.grad = None
    return expected, end - start


# Warm up each path (CUDA init, vmap tracing, caching allocator), then time it.
for _ in range(5):
    _, seconds = functorch_per_sample_grad()
    print(seconds)

result, seconds = functorch_per_sample_grad()
print(seconds)

for _ in range(5):
    _, seconds = opacus_per_sample_grad()
    print(seconds)

expected, seconds = opacus_per_sample_grad()
print(seconds)

result = [r.detach() for r in result]
print(len(result))  # number of parameter tensors with per-sample grads

# TODO: The commented-out check below shows that the per-sample grads computed
# by the two approaches differ. This is a little concerning; we should compare
# both against a source of truth (see the naive-autograd sketch below).
# for i, (r, e) in enumerate(list(zip(result, expected))[::-1]):
#     if torch.allclose(r, e, rtol=1e-5):
#         continue
#     print(-(i + 1), ((r - e) / (e + 1e-6)).abs().max())
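
# A minimal source-of-truth sketch for the TODO above: compute per-sample
# grads naively, one autograd call per sample, and spot-check the functorch
# result against them. `naive_grad_for_sample` and the tolerances are our
# own additions, not functorch or Opacus API; this relies on `weights`
# (the parameters returned by make_functional) requiring grad, which they do.
def naive_grad_for_sample(i):
    loss = compute_loss(weights, images[i], targets[i])
    return torch.autograd.grad(loss, weights)


# Spot-check a few samples with loose fp32 tolerances, since the different
# code paths may reduce in different orders.
for i in range(3):
    for r, e in zip(result, naive_grad_for_sample(i)):
        assert torch.allclose(r[i], e, rtol=1e-4, atol=1e-4)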