# Owner(s): ["module: unknown"]
#
# Tests for torch.utils.flop_counter.FlopCounterMode: verifies the FLOP
# formulas registered for matmul/conv/SDPA ops, module-level attribution,
# and custom op mappings.

import functools
import unittest

import torch
import torch.nn.functional as F
import torch.utils.flop_counter
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FLASH_ATTENTION,
    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
)
from torch.testing._internal.common_utils import (
    run_tests,
    TEST_WITH_TORCHDYNAMO,
    TestCase,
    skipIfRocm,
)

# torchvision is an optional dependency; only the resnet18 module-attribution
# test needs it, so gate that test instead of failing at import time.
try:
    from torchvision import models as torchvision_models

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

# SDPA-backend tests below require a CUDA device.
HAS_CUDA = torch.cuda.is_available()
def FlopCounterMode(*args, **kwargs):
    """Factory wrapping ``torch.utils.flop_counter.FlopCounterMode``.

    Forces ``display=False`` so running the tests does not print the
    per-module FLOP table to stdout.
    """
    return torch.utils.flop_counter.FlopCounterMode(*args, **kwargs, display=False)


def get_total_flops(mode) -> str:
    """Return the total FLOPs recorded under the "Global" key, as a string.

    Stringified so the result can be pinned with ``assertExpectedInline``.
    """
    return str(sum(v for _, v in mode.flop_counts["Global"].items()))


def T(*shape, requires_grad=False):
    """Shorthand for a random test tensor of the given shape."""
    return torch.randn(*shape, requires_grad=requires_grad)


@unittest.skipIf(
    TEST_WITH_TORCHDYNAMO, "torchdynamo doesn't work with __torch_dispatch__ right now"
)
class TestFlopCounter(TestCase):
    def test_flop_counter_variety(self):
        # Smoke test: a mix of mm/addmm/matmul/einsum/Linear should all be
        # counted, and their totals should sum under the "Global" key.
        mod = torch.nn.Linear(9, 10)
        with FlopCounterMode() as mode:
            torch.mm(T(4, 5), T(5, 6))
            torch.addmm(T(4, 6), T(4, 5), T(5, 6), beta=0.5, alpha=0.5)
            torch.matmul(T(5, 6), T(6, 7))
            torch.einsum("ab,bc->ac", T(6, 7), T(7, 8))
            mod(T(8, 9))

        self.assertExpectedInline(get_total_flops(mode), """3012""")
Coastguard Worker def test_op(self): 60*da0073e9SAndroid Build Coastguard Worker with FlopCounterMode() as mode: 61*da0073e9SAndroid Build Coastguard Worker torch.mm(T(4, 5), T(5, 6)) 62*da0073e9SAndroid Build Coastguard Worker # 4 * 6 * 2 * 5 = 240 63*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(get_total_flops(mode), """240""") 64*da0073e9SAndroid Build Coastguard Worker 65*da0073e9SAndroid Build Coastguard Worker with mode: 66*da0073e9SAndroid Build Coastguard Worker torch.bmm(T(3, 4, 5), T(3, 5, 6)) 67*da0073e9SAndroid Build Coastguard Worker # 3 * 4 * 6 * 2 * 5 = 720 68*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(get_total_flops(mode), """720""") 69*da0073e9SAndroid Build Coastguard Worker 70*da0073e9SAndroid Build Coastguard Worker with mode: 71*da0073e9SAndroid Build Coastguard Worker torch.addmm(T(4, 6), T(4, 5), T(5, 6)) 72*da0073e9SAndroid Build Coastguard Worker torch.addmm(T(4, 1), T(4, 5), T(5, 6)) 73*da0073e9SAndroid Build Coastguard Worker torch.addmm(T(6), T(4, 5), T(5, 6)) 74*da0073e9SAndroid Build Coastguard Worker 75*da0073e9SAndroid Build Coastguard Worker # 4 * 6 * 2 * 5 = 240 76*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(get_total_flops(mode), """720""") 77*da0073e9SAndroid Build Coastguard Worker 78*da0073e9SAndroid Build Coastguard Worker with mode: 79*da0073e9SAndroid Build Coastguard Worker torch.baddbmm(T(3, 4, 6), T(3, 4, 5), T(3, 5, 6)) 80*da0073e9SAndroid Build Coastguard Worker 81*da0073e9SAndroid Build Coastguard Worker # 3 * 4 * 6 * 2 * 5 = 720 82*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(get_total_flops(mode), """720""") 83*da0073e9SAndroid Build Coastguard Worker 84*da0073e9SAndroid Build Coastguard Worker with mode: 85*da0073e9SAndroid Build Coastguard Worker torch.conv2d(T(2, 3, 6, 6), T(6, 3, 4, 4), padding=1) 86*da0073e9SAndroid Build Coastguard Worker 87*da0073e9SAndroid Build Coastguard Worker # out_image_size = 2 * 5 * 5 
88*da0073e9SAndroid Build Coastguard Worker # kernel_size = 4 * 4 89*da0073e9SAndroid Build Coastguard Worker # c_out = 6 90*da0073e9SAndroid Build Coastguard Worker # c_in = 3 91*da0073e9SAndroid Build Coastguard Worker # out_image_size * kernel_size * c_out * 2 * c_in 92*da0073e9SAndroid Build Coastguard Worker 93*da0073e9SAndroid Build Coastguard Worker # NB: I don't think this properly accounts for padding? 94*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(get_total_flops(mode), """28800""") 95*da0073e9SAndroid Build Coastguard Worker 96*da0073e9SAndroid Build Coastguard Worker with mode: 97*da0073e9SAndroid Build Coastguard Worker torch.conv1d(T(2, 3, 6), T(6, 3, 4), padding=1) 98*da0073e9SAndroid Build Coastguard Worker 99*da0073e9SAndroid Build Coastguard Worker # out_image_size = 2 * 5 100*da0073e9SAndroid Build Coastguard Worker # kernel_size = 4 101*da0073e9SAndroid Build Coastguard Worker # c_out = 6 102*da0073e9SAndroid Build Coastguard Worker # c_in = 3 103*da0073e9SAndroid Build Coastguard Worker # out_image_size * kernel_size * c_out * 2 * c_in 104*da0073e9SAndroid Build Coastguard Worker 105*da0073e9SAndroid Build Coastguard Worker # NB: I don't think this properly accounts for padding? 
106*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(get_total_flops(mode), """1440""") 107*da0073e9SAndroid Build Coastguard Worker 108*da0073e9SAndroid Build Coastguard Worker def test_backward(self): 109*da0073e9SAndroid Build Coastguard Worker with FlopCounterMode() as mode: 110*da0073e9SAndroid Build Coastguard Worker a = T(4, 5, requires_grad=True) 111*da0073e9SAndroid Build Coastguard Worker a = torch.mm(a, T(5, 6)) 112*da0073e9SAndroid Build Coastguard Worker a = a.unsqueeze(0).expand(7, 4, 6) 113*da0073e9SAndroid Build Coastguard Worker a = torch.bmm(a, T(7, 6, 7)) 114*da0073e9SAndroid Build Coastguard Worker a.sum().backward() 115*da0073e9SAndroid Build Coastguard Worker 116*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(get_total_flops(mode), """5184""") 117*da0073e9SAndroid Build Coastguard Worker 118*da0073e9SAndroid Build Coastguard Worker def test_backward_reset(self): 119*da0073e9SAndroid Build Coastguard Worker with FlopCounterMode() as mode: 120*da0073e9SAndroid Build Coastguard Worker a = T(4, 5, requires_grad=True) 121*da0073e9SAndroid Build Coastguard Worker a.mm(a.t()).sum().backward() 122*da0073e9SAndroid Build Coastguard Worker a.mm(a.t()).sum().backward() 123*da0073e9SAndroid Build Coastguard Worker 124*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(get_total_flops(mode), """960""") 125*da0073e9SAndroid Build Coastguard Worker 126*da0073e9SAndroid Build Coastguard Worker def test_torchscript(self): 127*da0073e9SAndroid Build Coastguard Worker def foo(x): 128*da0073e9SAndroid Build Coastguard Worker return torch.mm(x, x) 129*da0073e9SAndroid Build Coastguard Worker 130*da0073e9SAndroid Build Coastguard Worker with FlopCounterMode() as mode: 131*da0073e9SAndroid Build Coastguard Worker foo(T(5, 5)) 132*da0073e9SAndroid Build Coastguard Worker unscripted_flops = get_total_flops(mode) 133*da0073e9SAndroid Build Coastguard Worker ts_foo = torch.jit.script(foo) 134*da0073e9SAndroid Build 
Coastguard Worker with mode: 135*da0073e9SAndroid Build Coastguard Worker ts_foo(T(5, 5)) 136*da0073e9SAndroid Build Coastguard Worker self.assertEqual(unscripted_flops, get_total_flops(mode)) 137*da0073e9SAndroid Build Coastguard Worker 138*da0073e9SAndroid Build Coastguard Worker def test_autograd_op(self): 139*da0073e9SAndroid Build Coastguard Worker class _CustomOp(torch.autograd.Function): 140*da0073e9SAndroid Build Coastguard Worker @staticmethod 141*da0073e9SAndroid Build Coastguard Worker def forward(ctx, input: torch.Tensor) -> torch.Tensor: 142*da0073e9SAndroid Build Coastguard Worker return torch.mm(input, input) 143*da0073e9SAndroid Build Coastguard Worker 144*da0073e9SAndroid Build Coastguard Worker @staticmethod 145*da0073e9SAndroid Build Coastguard Worker def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: 146*da0073e9SAndroid Build Coastguard Worker return torch.mm(grad_output, grad_output) + torch.mm( 147*da0073e9SAndroid Build Coastguard Worker grad_output, grad_output 148*da0073e9SAndroid Build Coastguard Worker ) 149*da0073e9SAndroid Build Coastguard Worker 150*da0073e9SAndroid Build Coastguard Worker a = T(5, 5, requires_grad=True) 151*da0073e9SAndroid Build Coastguard Worker with FlopCounterMode() as mode: 152*da0073e9SAndroid Build Coastguard Worker a = _CustomOp.apply(a) 153*da0073e9SAndroid Build Coastguard Worker a.sum().backward() 154*da0073e9SAndroid Build Coastguard Worker 155*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(get_total_flops(mode), """750""") 156*da0073e9SAndroid Build Coastguard Worker 157*da0073e9SAndroid Build Coastguard Worker def test_conv_backwards_as_decomposition(self): 158*da0073e9SAndroid Build Coastguard Worker # [conv backwards decomposition as conv forwards] 159*da0073e9SAndroid Build Coastguard Worker 160*da0073e9SAndroid Build Coastguard Worker class onlyConvs(torch.autograd.Function): 161*da0073e9SAndroid Build Coastguard Worker @staticmethod 162*da0073e9SAndroid Build 
Coastguard Worker def forward(inp, weight, transposed): 163*da0073e9SAndroid Build Coastguard Worker if not transposed: 164*da0073e9SAndroid Build Coastguard Worker return F.conv1d(inp, weight) 165*da0073e9SAndroid Build Coastguard Worker else: 166*da0073e9SAndroid Build Coastguard Worker return F.conv_transpose1d(inp, weight) 167*da0073e9SAndroid Build Coastguard Worker 168*da0073e9SAndroid Build Coastguard Worker @staticmethod 169*da0073e9SAndroid Build Coastguard Worker def setup_context(ctx, inputs, output): 170*da0073e9SAndroid Build Coastguard Worker inp, weight, transposed = inputs 171*da0073e9SAndroid Build Coastguard Worker ctx.save_for_backward(inp, weight) 172*da0073e9SAndroid Build Coastguard Worker ctx.transposed = transposed 173*da0073e9SAndroid Build Coastguard Worker 174*da0073e9SAndroid Build Coastguard Worker @staticmethod 175*da0073e9SAndroid Build Coastguard Worker def backward(ctx, grad_out): 176*da0073e9SAndroid Build Coastguard Worker inp, weight = ctx.saved_tensors 177*da0073e9SAndroid Build Coastguard Worker if not ctx.transposed: 178*da0073e9SAndroid Build Coastguard Worker grad_inp = F.conv_transpose1d(grad_out, weight) 179*da0073e9SAndroid Build Coastguard Worker grad_weight = F.conv1d(inp, grad_out) 180*da0073e9SAndroid Build Coastguard Worker return grad_inp, grad_weight, None 181*da0073e9SAndroid Build Coastguard Worker else: 182*da0073e9SAndroid Build Coastguard Worker grad_inp = F.conv1d(grad_out, weight) 183*da0073e9SAndroid Build Coastguard Worker grad_weight = F.conv1d( 184*da0073e9SAndroid Build Coastguard Worker grad_out.transpose(1, 0), inp.transpose(1, 0) 185*da0073e9SAndroid Build Coastguard Worker ) 186*da0073e9SAndroid Build Coastguard Worker return grad_inp, grad_weight.transpose(1, 0), None 187*da0073e9SAndroid Build Coastguard Worker 188*da0073e9SAndroid Build Coastguard Worker from torch.func import grad 189*da0073e9SAndroid Build Coastguard Worker 190*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3, 16, 
dtype=torch.float64) 191*da0073e9SAndroid Build Coastguard Worker weight = torch.randn(3, 4, 4, dtype=torch.float64) 192*da0073e9SAndroid Build Coastguard Worker 193*da0073e9SAndroid Build Coastguard Worker def boring_conv(x, weight, transposed): 194*da0073e9SAndroid Build Coastguard Worker if not transposed: 195*da0073e9SAndroid Build Coastguard Worker return F.conv1d(x, weight).pow(2).sum() 196*da0073e9SAndroid Build Coastguard Worker else: 197*da0073e9SAndroid Build Coastguard Worker return F.conv_transpose1d(x, weight).pow(2).sum() 198*da0073e9SAndroid Build Coastguard Worker 199*da0073e9SAndroid Build Coastguard Worker def only_convs(x, weight, transposed): 200*da0073e9SAndroid Build Coastguard Worker return onlyConvs.apply(x, weight, transposed).pow(2).sum() 201*da0073e9SAndroid Build Coastguard Worker 202*da0073e9SAndroid Build Coastguard Worker boring_grads = grad(boring_conv, argnums=(0, 1))(x, weight, True) 203*da0073e9SAndroid Build Coastguard Worker fun_grads = grad(only_convs, argnums=(0, 1))(x, weight, True) 204*da0073e9SAndroid Build Coastguard Worker 205*da0073e9SAndroid Build Coastguard Worker self.assertEqual(boring_grads, fun_grads) 206*da0073e9SAndroid Build Coastguard Worker 207*da0073e9SAndroid Build Coastguard Worker def test_convs(self): 208*da0073e9SAndroid Build Coastguard Worker def assert_equivalence(f, expected_forward=None): 209*da0073e9SAndroid Build Coastguard Worker with FlopCounterMode() as mode: 210*da0073e9SAndroid Build Coastguard Worker f() 211*da0073e9SAndroid Build Coastguard Worker conv_forward_flops = mode.get_flop_counts()["Global"][ 212*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.convolution 213*da0073e9SAndroid Build Coastguard Worker ] 214*da0073e9SAndroid Build Coastguard Worker conv_backward_flops = mode.get_flop_counts()["Global"][ 215*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.convolution_backward 216*da0073e9SAndroid Build Coastguard Worker ] 217*da0073e9SAndroid Build Coastguard Worker 
218*da0073e9SAndroid Build Coastguard Worker self.assertEqual(conv_forward_flops * 2, conv_backward_flops) 219*da0073e9SAndroid Build Coastguard Worker if expected_forward is not None: 220*da0073e9SAndroid Build Coastguard Worker self.assertEqual(conv_forward_flops, expected_forward) 221*da0073e9SAndroid Build Coastguard Worker 222*da0073e9SAndroid Build Coastguard Worker x = torch.rand(1, 1, 2, 2, requires_grad=True) 223*da0073e9SAndroid Build Coastguard Worker weight = torch.randn(1, 1, 2, 2, requires_grad=True) 224*da0073e9SAndroid Build Coastguard Worker assert_equivalence(lambda: F.conv_transpose2d(x, weight).sum().backward(), 32) 225*da0073e9SAndroid Build Coastguard Worker 226*da0073e9SAndroid Build Coastguard Worker x = torch.rand(1, 1, 2, 2, requires_grad=True) 227*da0073e9SAndroid Build Coastguard Worker weight = torch.randn(1, 1, 1, 1, requires_grad=True) 228*da0073e9SAndroid Build Coastguard Worker assert_equivalence(lambda: F.conv2d(x, weight).sum().backward(), 8) 229*da0073e9SAndroid Build Coastguard Worker 230*da0073e9SAndroid Build Coastguard Worker for in_channels, out_channels, groups in [ 231*da0073e9SAndroid Build Coastguard Worker (1, 1, 1), 232*da0073e9SAndroid Build Coastguard Worker (1, 3, 1), 233*da0073e9SAndroid Build Coastguard Worker (3, 1, 1), 234*da0073e9SAndroid Build Coastguard Worker (3, 7, 1), 235*da0073e9SAndroid Build Coastguard Worker (2, 4, 2), 236*da0073e9SAndroid Build Coastguard Worker (4, 2, 2), 237*da0073e9SAndroid Build Coastguard Worker ]: 238*da0073e9SAndroid Build Coastguard Worker x = torch.rand(1, in_channels, 4, 4, requires_grad=True) 239*da0073e9SAndroid Build Coastguard Worker weight = torch.randn(out_channels, in_channels, 2, 2, requires_grad=True) 240*da0073e9SAndroid Build Coastguard Worker assert_equivalence(lambda: F.conv2d(x, weight).sum().backward()) 241*da0073e9SAndroid Build Coastguard Worker transposed_weight = torch.randn( 242*da0073e9SAndroid Build Coastguard Worker in_channels, out_channels, 2, 2, 
requires_grad=True 243*da0073e9SAndroid Build Coastguard Worker ) 244*da0073e9SAndroid Build Coastguard Worker assert_equivalence( 245*da0073e9SAndroid Build Coastguard Worker lambda: F.conv_transpose2d(x, transposed_weight).sum().backward() 246*da0073e9SAndroid Build Coastguard Worker ) 247*da0073e9SAndroid Build Coastguard Worker 248*da0073e9SAndroid Build Coastguard Worker @skipIfNoTorchVision 249*da0073e9SAndroid Build Coastguard Worker def test_module(self): 250*da0073e9SAndroid Build Coastguard Worker resnet18 = torchvision_models.resnet18() 251*da0073e9SAndroid Build Coastguard Worker with FlopCounterMode(resnet18) as mode: 252*da0073e9SAndroid Build Coastguard Worker a = T(1, 3, 224, 224, requires_grad=True) 253*da0073e9SAndroid Build Coastguard Worker resnet18(a).sum().backward() 254*da0073e9SAndroid Build Coastguard Worker 255*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(get_total_flops(mode), """10884440064""") 256*da0073e9SAndroid Build Coastguard Worker layer1_conv_flops = mode.flop_counts["ResNet.layer1"][ 257*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.convolution 258*da0073e9SAndroid Build Coastguard Worker ] 259*da0073e9SAndroid Build Coastguard Worker layer1_conv_back_flops = mode.flop_counts["ResNet.layer1"][ 260*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.convolution_backward 261*da0073e9SAndroid Build Coastguard Worker ] 262*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(layer1_conv_flops), """924844032""") 263*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(layer1_conv_back_flops), """1849688064""") 264*da0073e9SAndroid Build Coastguard Worker 265*da0073e9SAndroid Build Coastguard Worker def test_conv_transpose_loop(self): 266*da0073e9SAndroid Build Coastguard Worker x = torch.rand(1, 4, 30, 2) 267*da0073e9SAndroid Build Coastguard Worker model = torch.nn.ConvTranspose2d(4, 8, (2, 2), stride=2) 268*da0073e9SAndroid Build Coastguard Worker 
269*da0073e9SAndroid Build Coastguard Worker with FlopCounterMode() as mode: 270*da0073e9SAndroid Build Coastguard Worker for i in range(50): 271*da0073e9SAndroid Build Coastguard Worker out = model(x) 272*da0073e9SAndroid Build Coastguard Worker out.sum().backward() 273*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(mode.get_total_flops()), """1536000""") 274*da0073e9SAndroid Build Coastguard Worker 275*da0073e9SAndroid Build Coastguard Worker def test_custom(self): 276*da0073e9SAndroid Build Coastguard Worker mode = FlopCounterMode( 277*da0073e9SAndroid Build Coastguard Worker custom_mapping={torch.ops.aten.add: lambda *args, out_shape: 5} 278*da0073e9SAndroid Build Coastguard Worker ) 279*da0073e9SAndroid Build Coastguard Worker with mode: 280*da0073e9SAndroid Build Coastguard Worker a = T(4, 5) 281*da0073e9SAndroid Build Coastguard Worker a + a 282*da0073e9SAndroid Build Coastguard Worker 283*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(get_total_flops(mode), """5""") 284*da0073e9SAndroid Build Coastguard Worker 285*da0073e9SAndroid Build Coastguard Worker def count(*args, out_val): 286*da0073e9SAndroid Build Coastguard Worker return out_val.numel() 287*da0073e9SAndroid Build Coastguard Worker 288*da0073e9SAndroid Build Coastguard Worker count._get_raw = True 289*da0073e9SAndroid Build Coastguard Worker 290*da0073e9SAndroid Build Coastguard Worker mode = FlopCounterMode(custom_mapping={torch.ops.aten.add: count}) 291*da0073e9SAndroid Build Coastguard Worker with mode: 292*da0073e9SAndroid Build Coastguard Worker a = T(4, 5) 293*da0073e9SAndroid Build Coastguard Worker a + a 294*da0073e9SAndroid Build Coastguard Worker 295*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(get_total_flops(mode), """20""") 296*da0073e9SAndroid Build Coastguard Worker 297*da0073e9SAndroid Build Coastguard Worker def test_noop(self): 298*da0073e9SAndroid Build Coastguard Worker with FlopCounterMode() as mode: 
299*da0073e9SAndroid Build Coastguard Worker T(4, 5).cos() 300*da0073e9SAndroid Build Coastguard Worker 301*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not HAS_CUDA, "CUDA not available") 302*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 303*da0073e9SAndroid Build Coastguard Worker not PLATFORM_SUPPORTS_FLASH_ATTENTION 304*da0073e9SAndroid Build Coastguard Worker or not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, 305*da0073e9SAndroid Build Coastguard Worker "Does not support all SDPA backends (pre-SM80 hardware on CUDA)", 306*da0073e9SAndroid Build Coastguard Worker ) 307*da0073e9SAndroid Build Coastguard Worker def test_sdpa(self): 308*da0073e9SAndroid Build Coastguard Worker batch_size = 4 309*da0073e9SAndroid Build Coastguard Worker n_heads = 8 310*da0073e9SAndroid Build Coastguard Worker seq_len_q = 128 311*da0073e9SAndroid Build Coastguard Worker seq_len_k = 256 312*da0073e9SAndroid Build Coastguard Worker head_dim = 64 313*da0073e9SAndroid Build Coastguard Worker head_dim_v = 64 314*da0073e9SAndroid Build Coastguard Worker dtype = torch.float16 315*da0073e9SAndroid Build Coastguard Worker 316*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(0) 317*da0073e9SAndroid Build Coastguard Worker 318*da0073e9SAndroid Build Coastguard Worker def get_flops( 319*da0073e9SAndroid Build Coastguard Worker batch_size, 320*da0073e9SAndroid Build Coastguard Worker n_heads, 321*da0073e9SAndroid Build Coastguard Worker seq_len_q, 322*da0073e9SAndroid Build Coastguard Worker seq_len_k, 323*da0073e9SAndroid Build Coastguard Worker head_dim, 324*da0073e9SAndroid Build Coastguard Worker head_dim_v, 325*da0073e9SAndroid Build Coastguard Worker dtype, 326*da0073e9SAndroid Build Coastguard Worker backend, 327*da0073e9SAndroid Build Coastguard Worker with_backward=False, 328*da0073e9SAndroid Build Coastguard Worker ): 329*da0073e9SAndroid Build Coastguard Worker query = torch.randn( 330*da0073e9SAndroid Build Coastguard Worker batch_size, 
331*da0073e9SAndroid Build Coastguard Worker n_heads, 332*da0073e9SAndroid Build Coastguard Worker seq_len_q, 333*da0073e9SAndroid Build Coastguard Worker head_dim, 334*da0073e9SAndroid Build Coastguard Worker device="cuda", 335*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 336*da0073e9SAndroid Build Coastguard Worker requires_grad=True, 337*da0073e9SAndroid Build Coastguard Worker ) 338*da0073e9SAndroid Build Coastguard Worker key = torch.randn( 339*da0073e9SAndroid Build Coastguard Worker batch_size, 340*da0073e9SAndroid Build Coastguard Worker n_heads, 341*da0073e9SAndroid Build Coastguard Worker seq_len_k, 342*da0073e9SAndroid Build Coastguard Worker head_dim, 343*da0073e9SAndroid Build Coastguard Worker device="cuda", 344*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 345*da0073e9SAndroid Build Coastguard Worker requires_grad=True, 346*da0073e9SAndroid Build Coastguard Worker ) 347*da0073e9SAndroid Build Coastguard Worker value = torch.randn( 348*da0073e9SAndroid Build Coastguard Worker batch_size, 349*da0073e9SAndroid Build Coastguard Worker n_heads, 350*da0073e9SAndroid Build Coastguard Worker seq_len_k, 351*da0073e9SAndroid Build Coastguard Worker head_dim_v, 352*da0073e9SAndroid Build Coastguard Worker device="cuda", 353*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 354*da0073e9SAndroid Build Coastguard Worker requires_grad=True, 355*da0073e9SAndroid Build Coastguard Worker ) 356*da0073e9SAndroid Build Coastguard Worker 357*da0073e9SAndroid Build Coastguard Worker if backend == "math": 358*da0073e9SAndroid Build Coastguard Worker backend = torch.backends.cuda.sdp_kernel( 359*da0073e9SAndroid Build Coastguard Worker enable_flash=False, enable_math=True, enable_mem_efficient=False 360*da0073e9SAndroid Build Coastguard Worker ) 361*da0073e9SAndroid Build Coastguard Worker elif backend == "flash": 362*da0073e9SAndroid Build Coastguard Worker backend = torch.backends.cuda.sdp_kernel( 363*da0073e9SAndroid Build Coastguard Worker 
enable_flash=True, enable_math=False, enable_mem_efficient=False 364*da0073e9SAndroid Build Coastguard Worker ) 365*da0073e9SAndroid Build Coastguard Worker elif backend == "mem_efficient": 366*da0073e9SAndroid Build Coastguard Worker backend = torch.backends.cuda.sdp_kernel( 367*da0073e9SAndroid Build Coastguard Worker enable_flash=False, enable_math=False, enable_mem_efficient=True 368*da0073e9SAndroid Build Coastguard Worker ) 369*da0073e9SAndroid Build Coastguard Worker 370*da0073e9SAndroid Build Coastguard Worker mode = FlopCounterMode() 371*da0073e9SAndroid Build Coastguard Worker with backend, mode: 372*da0073e9SAndroid Build Coastguard Worker out = F.scaled_dot_product_attention( 373*da0073e9SAndroid Build Coastguard Worker query, key, value, dropout_p=0, is_causal=True 374*da0073e9SAndroid Build Coastguard Worker ) 375*da0073e9SAndroid Build Coastguard Worker if with_backward: 376*da0073e9SAndroid Build Coastguard Worker out.sum().backward() 377*da0073e9SAndroid Build Coastguard Worker return int(get_total_flops(mode)) 378*da0073e9SAndroid Build Coastguard Worker 379*da0073e9SAndroid Build Coastguard Worker # Sets seq_len_q == seq_len_k and dim_q == dim_v 380*da0073e9SAndroid Build Coastguard Worker run_uniform_flops = functools.partial( 381*da0073e9SAndroid Build Coastguard Worker get_flops, 382*da0073e9SAndroid Build Coastguard Worker batch_size, 383*da0073e9SAndroid Build Coastguard Worker n_heads, 384*da0073e9SAndroid Build Coastguard Worker seq_len_q, 385*da0073e9SAndroid Build Coastguard Worker seq_len_q, 386*da0073e9SAndroid Build Coastguard Worker head_dim, 387*da0073e9SAndroid Build Coastguard Worker head_dim, 388*da0073e9SAndroid Build Coastguard Worker dtype, 389*da0073e9SAndroid Build Coastguard Worker ) 390*da0073e9SAndroid Build Coastguard Worker 391*da0073e9SAndroid Build Coastguard Worker flops = [ 392*da0073e9SAndroid Build Coastguard Worker run_uniform_flops(backend, with_backward=False) 393*da0073e9SAndroid Build Coastguard Worker for 
backend in ["math", "flash", "mem_efficient"] 394*da0073e9SAndroid Build Coastguard Worker ] 395*da0073e9SAndroid Build Coastguard Worker flops_fw_math, flops_fw_flash, flops_fw_efficient = flops 396*da0073e9SAndroid Build Coastguard Worker self.assertEqual(flops_fw_math, flops_fw_flash) 397*da0073e9SAndroid Build Coastguard Worker self.assertEqual(flops_fw_math, flops_fw_efficient) 398*da0073e9SAndroid Build Coastguard Worker 399*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(flops_fw_math), """134217728""") 400*da0073e9SAndroid Build Coastguard Worker 401*da0073e9SAndroid Build Coastguard Worker flops = [ 402*da0073e9SAndroid Build Coastguard Worker run_uniform_flops(backend, with_backward=True) 403*da0073e9SAndroid Build Coastguard Worker for backend in ["math", "flash", "mem_efficient"] 404*da0073e9SAndroid Build Coastguard Worker ] 405*da0073e9SAndroid Build Coastguard Worker flops_fw_bw_math, flops_fw_bw_flash, flops_fw_bw_efficient = flops 406*da0073e9SAndroid Build Coastguard Worker self.assertEqual(flops_fw_math * 3, flops_fw_bw_math) 407*da0073e9SAndroid Build Coastguard Worker self.assertEqual(flops_fw_math * 7 // 2, flops_fw_bw_flash) 408*da0073e9SAndroid Build Coastguard Worker self.assertEqual(flops_fw_bw_flash, flops_fw_bw_efficient) 409*da0073e9SAndroid Build Coastguard Worker 410*da0073e9SAndroid Build Coastguard Worker run_nonuniform_flops = functools.partial( 411*da0073e9SAndroid Build Coastguard Worker get_flops, 412*da0073e9SAndroid Build Coastguard Worker batch_size, 413*da0073e9SAndroid Build Coastguard Worker n_heads, 414*da0073e9SAndroid Build Coastguard Worker seq_len_q, 415*da0073e9SAndroid Build Coastguard Worker seq_len_k, 416*da0073e9SAndroid Build Coastguard Worker head_dim, 417*da0073e9SAndroid Build Coastguard Worker head_dim_v, 418*da0073e9SAndroid Build Coastguard Worker dtype, 419*da0073e9SAndroid Build Coastguard Worker ) 420*da0073e9SAndroid Build Coastguard Worker # Flash does not support non-uniform 
attention, i.e. seq_len_q != seq_len_k or dim_q != dim_v" 421*da0073e9SAndroid Build Coastguard Worker non_uniform_backends = ["math", "mem_efficient"] 422*da0073e9SAndroid Build Coastguard Worker flops = [ 423*da0073e9SAndroid Build Coastguard Worker run_nonuniform_flops(backend, with_backward=False) 424*da0073e9SAndroid Build Coastguard Worker for backend in non_uniform_backends 425*da0073e9SAndroid Build Coastguard Worker ] 426*da0073e9SAndroid Build Coastguard Worker flops_fw_math, flops_fw_efficient = flops 427*da0073e9SAndroid Build Coastguard Worker self.assertEqual(flops_fw_math, flops_fw_efficient) 428*da0073e9SAndroid Build Coastguard Worker 429*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(flops_fw_math), """268435456""") 430*da0073e9SAndroid Build Coastguard Worker 431*da0073e9SAndroid Build Coastguard Worker flops = [ 432*da0073e9SAndroid Build Coastguard Worker run_nonuniform_flops(backend, with_backward=True) 433*da0073e9SAndroid Build Coastguard Worker for backend in non_uniform_backends 434*da0073e9SAndroid Build Coastguard Worker ] 435*da0073e9SAndroid Build Coastguard Worker flops_fw_bw_math, flops_fw_bw_efficient = flops 436*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(flops_fw_bw_math), """805306368""") 437*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(flops_fw_bw_efficient), """939524096""") 438*da0073e9SAndroid Build Coastguard Worker 439*da0073e9SAndroid Build Coastguard Worker @skipIfRocm # Nested tensor 440*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not HAS_CUDA, "CUDA not available") 441*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 442*da0073e9SAndroid Build Coastguard Worker not PLATFORM_SUPPORTS_FLASH_ATTENTION 443*da0073e9SAndroid Build Coastguard Worker or not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, 444*da0073e9SAndroid Build Coastguard Worker "Does not support all SDPA backends (pre-SM80 hardware on CUDA)", 445*da0073e9SAndroid Build 
Coastguard Worker ) 446*da0073e9SAndroid Build Coastguard Worker def test_sdpa_nested_tensor(self): 447*da0073e9SAndroid Build Coastguard Worker def get_flops(q, k, v, backend, with_backward=False): 448*da0073e9SAndroid Build Coastguard Worker mode = FlopCounterMode() 449*da0073e9SAndroid Build Coastguard Worker 450*da0073e9SAndroid Build Coastguard Worker if backend == "math": 451*da0073e9SAndroid Build Coastguard Worker backend = torch.backends.cuda.sdp_kernel( 452*da0073e9SAndroid Build Coastguard Worker enable_flash=False, enable_math=True, enable_mem_efficient=False 453*da0073e9SAndroid Build Coastguard Worker ) 454*da0073e9SAndroid Build Coastguard Worker elif backend == "flash": 455*da0073e9SAndroid Build Coastguard Worker backend = torch.backends.cuda.sdp_kernel( 456*da0073e9SAndroid Build Coastguard Worker enable_flash=True, enable_math=False, enable_mem_efficient=False 457*da0073e9SAndroid Build Coastguard Worker ) 458*da0073e9SAndroid Build Coastguard Worker elif backend == "mem_efficient": 459*da0073e9SAndroid Build Coastguard Worker backend = torch.backends.cuda.sdp_kernel( 460*da0073e9SAndroid Build Coastguard Worker enable_flash=False, enable_math=False, enable_mem_efficient=True 461*da0073e9SAndroid Build Coastguard Worker ) 462*da0073e9SAndroid Build Coastguard Worker 463*da0073e9SAndroid Build Coastguard Worker with backend, mode: 464*da0073e9SAndroid Build Coastguard Worker out = F.scaled_dot_product_attention( 465*da0073e9SAndroid Build Coastguard Worker q, k, v, dropout_p=0, is_causal=True 466*da0073e9SAndroid Build Coastguard Worker ) 467*da0073e9SAndroid Build Coastguard Worker if with_backward: 468*da0073e9SAndroid Build Coastguard Worker if out.is_nested: 469*da0073e9SAndroid Build Coastguard Worker out.values().sum().backward() 470*da0073e9SAndroid Build Coastguard Worker else: 471*da0073e9SAndroid Build Coastguard Worker out.sum().backward() 472*da0073e9SAndroid Build Coastguard Worker 473*da0073e9SAndroid Build Coastguard Worker return 
int(get_total_flops(mode)) 474*da0073e9SAndroid Build Coastguard Worker 475*da0073e9SAndroid Build Coastguard Worker def get_nested_inputs( 476*da0073e9SAndroid Build Coastguard Worker batch_size, 477*da0073e9SAndroid Build Coastguard Worker n_heads, 478*da0073e9SAndroid Build Coastguard Worker max_seq_len_q, 479*da0073e9SAndroid Build Coastguard Worker max_seq_len_k, 480*da0073e9SAndroid Build Coastguard Worker head_dim, 481*da0073e9SAndroid Build Coastguard Worker head_dim_v, 482*da0073e9SAndroid Build Coastguard Worker dtype, 483*da0073e9SAndroid Build Coastguard Worker ): 484*da0073e9SAndroid Build Coastguard Worker q_lengths = torch.tensor( 485*da0073e9SAndroid Build Coastguard Worker [ 486*da0073e9SAndroid Build Coastguard Worker max_seq_len_q // 4, 487*da0073e9SAndroid Build Coastguard Worker max_seq_len_q // 4 * 2, 488*da0073e9SAndroid Build Coastguard Worker max_seq_len_q // 4 * 3, 489*da0073e9SAndroid Build Coastguard Worker max_seq_len_q // 4 * 4, 490*da0073e9SAndroid Build Coastguard Worker ] 491*da0073e9SAndroid Build Coastguard Worker ) 492*da0073e9SAndroid Build Coastguard Worker k_lengths = torch.tensor( 493*da0073e9SAndroid Build Coastguard Worker [ 494*da0073e9SAndroid Build Coastguard Worker max_seq_len_k // 4, 495*da0073e9SAndroid Build Coastguard Worker max_seq_len_k // 4 * 2, 496*da0073e9SAndroid Build Coastguard Worker max_seq_len_k // 4 * 3, 497*da0073e9SAndroid Build Coastguard Worker max_seq_len_k // 4 * 4, 498*da0073e9SAndroid Build Coastguard Worker ] 499*da0073e9SAndroid Build Coastguard Worker ) 500*da0073e9SAndroid Build Coastguard Worker q_offsets, k_offsets = ( 501*da0073e9SAndroid Build Coastguard Worker torch.cat((torch.tensor([0]), torch.cumsum(lengths, dim=0))).cuda() 502*da0073e9SAndroid Build Coastguard Worker for lengths in (q_lengths, k_lengths) 503*da0073e9SAndroid Build Coastguard Worker ) 504*da0073e9SAndroid Build Coastguard Worker q_values = torch.randn( 505*da0073e9SAndroid Build Coastguard Worker q_offsets[-1], 
506*da0073e9SAndroid Build Coastguard Worker head_dim * n_heads, 507*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 508*da0073e9SAndroid Build Coastguard Worker requires_grad=True, 509*da0073e9SAndroid Build Coastguard Worker device="cuda", 510*da0073e9SAndroid Build Coastguard Worker ) 511*da0073e9SAndroid Build Coastguard Worker k_values = torch.randn( 512*da0073e9SAndroid Build Coastguard Worker k_offsets[-1], 513*da0073e9SAndroid Build Coastguard Worker head_dim * n_heads, 514*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 515*da0073e9SAndroid Build Coastguard Worker requires_grad=True, 516*da0073e9SAndroid Build Coastguard Worker device="cuda", 517*da0073e9SAndroid Build Coastguard Worker ) 518*da0073e9SAndroid Build Coastguard Worker v_values = torch.randn( 519*da0073e9SAndroid Build Coastguard Worker k_offsets[-1], 520*da0073e9SAndroid Build Coastguard Worker head_dim_v * n_heads, 521*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 522*da0073e9SAndroid Build Coastguard Worker requires_grad=True, 523*da0073e9SAndroid Build Coastguard Worker device="cuda", 524*da0073e9SAndroid Build Coastguard Worker ) 525*da0073e9SAndroid Build Coastguard Worker 526*da0073e9SAndroid Build Coastguard Worker q = torch.nested.nested_tensor_from_jagged(q_values, q_offsets) 527*da0073e9SAndroid Build Coastguard Worker k = torch.nested.nested_tensor_from_jagged(k_values, k_offsets) 528*da0073e9SAndroid Build Coastguard Worker v = torch.nested.nested_tensor_from_jagged(v_values, k_offsets) 529*da0073e9SAndroid Build Coastguard Worker 530*da0073e9SAndroid Build Coastguard Worker q = q.view(batch_size, -1, n_heads, head_dim).transpose(1, 2) 531*da0073e9SAndroid Build Coastguard Worker k = k.view(batch_size, -1, n_heads, head_dim).transpose(1, 2) 532*da0073e9SAndroid Build Coastguard Worker v = v.view(batch_size, -1, n_heads, head_dim_v).transpose(1, 2) 533*da0073e9SAndroid Build Coastguard Worker 534*da0073e9SAndroid Build Coastguard Worker return q, k, v 
535*da0073e9SAndroid Build Coastguard Worker 536*da0073e9SAndroid Build Coastguard Worker def get_dense_flops(q, k, v, backend, with_backward=False): 537*da0073e9SAndroid Build Coastguard Worker def split_tensor(x): 538*da0073e9SAndroid Build Coastguard Worker return ( 539*da0073e9SAndroid Build Coastguard Worker y.unsqueeze(0).transpose(1, 2).detach().requires_grad_(True) 540*da0073e9SAndroid Build Coastguard Worker for y in x.transpose(1, 2).unbind(0) 541*da0073e9SAndroid Build Coastguard Worker ) 542*da0073e9SAndroid Build Coastguard Worker 543*da0073e9SAndroid Build Coastguard Worker q_tensors = split_tensor(q) 544*da0073e9SAndroid Build Coastguard Worker k_tensors = split_tensor(k) 545*da0073e9SAndroid Build Coastguard Worker v_tensors = split_tensor(v) 546*da0073e9SAndroid Build Coastguard Worker 547*da0073e9SAndroid Build Coastguard Worker flops = 0 548*da0073e9SAndroid Build Coastguard Worker for q_i, k_i, v_i in zip(q_tensors, k_tensors, v_tensors): 549*da0073e9SAndroid Build Coastguard Worker flops += get_flops( 550*da0073e9SAndroid Build Coastguard Worker q_i, k_i, v_i, backend=backend, with_backward=with_backward 551*da0073e9SAndroid Build Coastguard Worker ) 552*da0073e9SAndroid Build Coastguard Worker 553*da0073e9SAndroid Build Coastguard Worker return flops 554*da0073e9SAndroid Build Coastguard Worker 555*da0073e9SAndroid Build Coastguard Worker uniform_config = { 556*da0073e9SAndroid Build Coastguard Worker "batch_size": 4, 557*da0073e9SAndroid Build Coastguard Worker "n_heads": 8, 558*da0073e9SAndroid Build Coastguard Worker "max_seq_len_q": 128, 559*da0073e9SAndroid Build Coastguard Worker "max_seq_len_k": 128, 560*da0073e9SAndroid Build Coastguard Worker "head_dim": 64, 561*da0073e9SAndroid Build Coastguard Worker "head_dim_v": 64, 562*da0073e9SAndroid Build Coastguard Worker "dtype": torch.float16, 563*da0073e9SAndroid Build Coastguard Worker } 564*da0073e9SAndroid Build Coastguard Worker 565*da0073e9SAndroid Build Coastguard Worker # 
max_seq_len_q != max_seq_len_k doesn't work for flash attention with dense tensors. 566*da0073e9SAndroid Build Coastguard Worker differing_config = { 567*da0073e9SAndroid Build Coastguard Worker "batch_size": 4, 568*da0073e9SAndroid Build Coastguard Worker "n_heads": 8, 569*da0073e9SAndroid Build Coastguard Worker "max_seq_len_q": 128, 570*da0073e9SAndroid Build Coastguard Worker "max_seq_len_k": 256, 571*da0073e9SAndroid Build Coastguard Worker "head_dim": 64, 572*da0073e9SAndroid Build Coastguard Worker "head_dim_v": 64, 573*da0073e9SAndroid Build Coastguard Worker "dtype": torch.float16, 574*da0073e9SAndroid Build Coastguard Worker } 575*da0073e9SAndroid Build Coastguard Worker 576*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 577*da0073e9SAndroid Build Coastguard Worker get_dense_flops( 578*da0073e9SAndroid Build Coastguard Worker *get_nested_inputs(**uniform_config), 579*da0073e9SAndroid Build Coastguard Worker backend="flash", 580*da0073e9SAndroid Build Coastguard Worker with_backward=False, 581*da0073e9SAndroid Build Coastguard Worker ), 582*da0073e9SAndroid Build Coastguard Worker get_flops( 583*da0073e9SAndroid Build Coastguard Worker *get_nested_inputs(**uniform_config), 584*da0073e9SAndroid Build Coastguard Worker backend="flash", 585*da0073e9SAndroid Build Coastguard Worker with_backward=False, 586*da0073e9SAndroid Build Coastguard Worker ), 587*da0073e9SAndroid Build Coastguard Worker ) 588*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 589*da0073e9SAndroid Build Coastguard Worker get_dense_flops( 590*da0073e9SAndroid Build Coastguard Worker *get_nested_inputs(**uniform_config), 591*da0073e9SAndroid Build Coastguard Worker backend="mem_efficient", 592*da0073e9SAndroid Build Coastguard Worker with_backward=False, 593*da0073e9SAndroid Build Coastguard Worker ), 594*da0073e9SAndroid Build Coastguard Worker get_flops( 595*da0073e9SAndroid Build Coastguard Worker *get_nested_inputs(**uniform_config), 596*da0073e9SAndroid Build 
Coastguard Worker backend="mem_efficient", 597*da0073e9SAndroid Build Coastguard Worker with_backward=False, 598*da0073e9SAndroid Build Coastguard Worker ), 599*da0073e9SAndroid Build Coastguard Worker ) 600*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 601*da0073e9SAndroid Build Coastguard Worker get_dense_flops( 602*da0073e9SAndroid Build Coastguard Worker *get_nested_inputs(**differing_config), 603*da0073e9SAndroid Build Coastguard Worker backend="mem_efficient", 604*da0073e9SAndroid Build Coastguard Worker with_backward=False, 605*da0073e9SAndroid Build Coastguard Worker ), 606*da0073e9SAndroid Build Coastguard Worker get_flops( 607*da0073e9SAndroid Build Coastguard Worker *get_nested_inputs(**differing_config), 608*da0073e9SAndroid Build Coastguard Worker backend="mem_efficient", 609*da0073e9SAndroid Build Coastguard Worker with_backward=False, 610*da0073e9SAndroid Build Coastguard Worker ), 611*da0073e9SAndroid Build Coastguard Worker ) 612*da0073e9SAndroid Build Coastguard Worker 613*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 614*da0073e9SAndroid Build Coastguard Worker get_dense_flops( 615*da0073e9SAndroid Build Coastguard Worker *get_nested_inputs(**uniform_config), 616*da0073e9SAndroid Build Coastguard Worker backend="flash", 617*da0073e9SAndroid Build Coastguard Worker with_backward=True, 618*da0073e9SAndroid Build Coastguard Worker ), 619*da0073e9SAndroid Build Coastguard Worker get_flops( 620*da0073e9SAndroid Build Coastguard Worker *get_nested_inputs(**uniform_config), 621*da0073e9SAndroid Build Coastguard Worker backend="flash", 622*da0073e9SAndroid Build Coastguard Worker with_backward=True, 623*da0073e9SAndroid Build Coastguard Worker ), 624*da0073e9SAndroid Build Coastguard Worker ) 625*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 626*da0073e9SAndroid Build Coastguard Worker get_dense_flops( 627*da0073e9SAndroid Build Coastguard Worker *get_nested_inputs(**uniform_config), 628*da0073e9SAndroid Build 
Coastguard Worker backend="mem_efficient", 629*da0073e9SAndroid Build Coastguard Worker with_backward=True, 630*da0073e9SAndroid Build Coastguard Worker ), 631*da0073e9SAndroid Build Coastguard Worker get_flops( 632*da0073e9SAndroid Build Coastguard Worker *get_nested_inputs(**uniform_config), 633*da0073e9SAndroid Build Coastguard Worker backend="mem_efficient", 634*da0073e9SAndroid Build Coastguard Worker with_backward=True, 635*da0073e9SAndroid Build Coastguard Worker ), 636*da0073e9SAndroid Build Coastguard Worker ) 637*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 638*da0073e9SAndroid Build Coastguard Worker get_dense_flops( 639*da0073e9SAndroid Build Coastguard Worker *get_nested_inputs(**differing_config), 640*da0073e9SAndroid Build Coastguard Worker backend="mem_efficient", 641*da0073e9SAndroid Build Coastguard Worker with_backward=True, 642*da0073e9SAndroid Build Coastguard Worker ), 643*da0073e9SAndroid Build Coastguard Worker get_flops( 644*da0073e9SAndroid Build Coastguard Worker *get_nested_inputs(**differing_config), 645*da0073e9SAndroid Build Coastguard Worker backend="mem_efficient", 646*da0073e9SAndroid Build Coastguard Worker with_backward=True, 647*da0073e9SAndroid Build Coastguard Worker ), 648*da0073e9SAndroid Build Coastguard Worker ) 649*da0073e9SAndroid Build Coastguard Worker 650*da0073e9SAndroid Build Coastguard Worker @skipIfRocm # Nested tensor 651*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not HAS_CUDA, "CUDA not available") 652*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 653*da0073e9SAndroid Build Coastguard Worker not PLATFORM_SUPPORTS_FLASH_ATTENTION, 654*da0073e9SAndroid Build Coastguard Worker "Does not support all SDPA backends (pre-SM80 hardware on CUDA)", 655*da0073e9SAndroid Build Coastguard Worker ) 656*da0073e9SAndroid Build Coastguard Worker def test_nested_attention_fake_tensors(self): 657*da0073e9SAndroid Build Coastguard Worker x = torch.randn(123, 4, 16, device="cuda", 
dtype=torch.bfloat16) 658*da0073e9SAndroid Build Coastguard Worker offsets = torch.tensor([0, 30, 60, 90, 123], device="cuda") 659*da0073e9SAndroid Build Coastguard Worker max_seqlen = 40 660*da0073e9SAndroid Build Coastguard Worker with FakeTensorMode() as fake_mode: 661*da0073e9SAndroid Build Coastguard Worker fake_x = fake_mode.from_tensor(x) 662*da0073e9SAndroid Build Coastguard Worker fake_offsets = fake_mode.from_tensor(offsets) 663*da0073e9SAndroid Build Coastguard Worker 664*da0073e9SAndroid Build Coastguard Worker with FlopCounterMode() as fake_flop_counter_mode: 665*da0073e9SAndroid Build Coastguard Worker torch.ops.aten._flash_attention_forward( 666*da0073e9SAndroid Build Coastguard Worker fake_x, 667*da0073e9SAndroid Build Coastguard Worker fake_x, 668*da0073e9SAndroid Build Coastguard Worker fake_x, 669*da0073e9SAndroid Build Coastguard Worker fake_offsets, 670*da0073e9SAndroid Build Coastguard Worker fake_offsets, 671*da0073e9SAndroid Build Coastguard Worker max_seqlen, 672*da0073e9SAndroid Build Coastguard Worker max_seqlen, 673*da0073e9SAndroid Build Coastguard Worker 0.0, 674*da0073e9SAndroid Build Coastguard Worker False, 675*da0073e9SAndroid Build Coastguard Worker False, 676*da0073e9SAndroid Build Coastguard Worker ) 677*da0073e9SAndroid Build Coastguard Worker 678*da0073e9SAndroid Build Coastguard Worker dense_x = torch.randn(4, 40, 4, 16, dtype=torch.bfloat16, device="cuda").transpose(1, 2) 679*da0073e9SAndroid Build Coastguard Worker 680*da0073e9SAndroid Build Coastguard Worker with FlopCounterMode() as real_flop_counter_mode: 681*da0073e9SAndroid Build Coastguard Worker torch.ops.aten._flash_attention_forward( 682*da0073e9SAndroid Build Coastguard Worker dense_x, 683*da0073e9SAndroid Build Coastguard Worker dense_x, 684*da0073e9SAndroid Build Coastguard Worker dense_x, 685*da0073e9SAndroid Build Coastguard Worker None, 686*da0073e9SAndroid Build Coastguard Worker None, 687*da0073e9SAndroid Build Coastguard Worker max_seqlen, 
688*da0073e9SAndroid Build Coastguard Worker max_seqlen, 689*da0073e9SAndroid Build Coastguard Worker 0.0, 690*da0073e9SAndroid Build Coastguard Worker False, 691*da0073e9SAndroid Build Coastguard Worker False, 692*da0073e9SAndroid Build Coastguard Worker ) 693*da0073e9SAndroid Build Coastguard Worker 694*da0073e9SAndroid Build Coastguard Worker self.assertEqual(int(get_total_flops(fake_flop_counter_mode)), int(get_total_flops(real_flop_counter_mode))) 695*da0073e9SAndroid Build Coastguard Worker 696*da0073e9SAndroid Build Coastguard Worker 697*da0073e9SAndroid Build Coastguard Worker def test_addmm_out(self): 698*da0073e9SAndroid Build Coastguard Worker def f(x): 699*da0073e9SAndroid Build Coastguard Worker y = torch.zeros(10, 10) 700*da0073e9SAndroid Build Coastguard Worker return torch.mm(x, x, out=y) 701*da0073e9SAndroid Build Coastguard Worker 702*da0073e9SAndroid Build Coastguard Worker with FlopCounterMode() as mode: 703*da0073e9SAndroid Build Coastguard Worker f(torch.randn(10, 10)) 704*da0073e9SAndroid Build Coastguard Worker 705*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(get_total_flops(mode), """2000""") 706*da0073e9SAndroid Build Coastguard Worker 707*da0073e9SAndroid Build Coastguard Worker def test_hook_registration(self): 708*da0073e9SAndroid Build Coastguard Worker model = torch.nn.Linear(100, 100) 709*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 100) 710*da0073e9SAndroid Build Coastguard Worker 711*da0073e9SAndroid Build Coastguard Worker with FlopCounterMode() as mode: 712*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(torch.nn.modules.module._global_forward_pre_hooks), 1) 713*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(torch.nn.modules.module._global_forward_hooks), 1) 714*da0073e9SAndroid Build Coastguard Worker model(x).sum().backward() 715*da0073e9SAndroid Build Coastguard Worker 716*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(len(torch.nn.modules.module._global_forward_pre_hooks), 0) 717*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(torch.nn.modules.module._global_forward_hooks), 0) 718*da0073e9SAndroid Build Coastguard Worker 719*da0073e9SAndroid Build Coastguard Worker def test_pytrees(self): 720*da0073e9SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 721*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 722*da0073e9SAndroid Build Coastguard Worker x = x["a"].relu_() 723*da0073e9SAndroid Build Coastguard Worker return {"a": torch.mm(x, x)} 724*da0073e9SAndroid Build Coastguard Worker 725*da0073e9SAndroid Build Coastguard Worker class Mod(torch.nn.Module): 726*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 727*da0073e9SAndroid Build Coastguard Worker super().__init__() 728*da0073e9SAndroid Build Coastguard Worker self.a = Foo() 729*da0073e9SAndroid Build Coastguard Worker self.b = Foo() 730*da0073e9SAndroid Build Coastguard Worker 731*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 732*da0073e9SAndroid Build Coastguard Worker return self.b(self.a(x)) 733*da0073e9SAndroid Build Coastguard Worker 734*da0073e9SAndroid Build Coastguard Worker mod = Mod() 735*da0073e9SAndroid Build Coastguard Worker with FlopCounterMode() as mode: 736*da0073e9SAndroid Build Coastguard Worker mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[ 737*da0073e9SAndroid Build Coastguard Worker "a" 738*da0073e9SAndroid Build Coastguard Worker ].sum().backward() 739*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 740*da0073e9SAndroid Build Coastguard Worker (mode.flop_counts["Mod"][torch.ops.aten.mm]), """12000""" 741*da0073e9SAndroid Build Coastguard Worker ) 742*da0073e9SAndroid Build Coastguard Worker 743*da0073e9SAndroid Build Coastguard Worker class Mod2(torch.nn.Module): 744*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 745*da0073e9SAndroid Build Coastguard Worker return 
(torch.mm(x, x),) 746*da0073e9SAndroid Build Coastguard Worker 747*da0073e9SAndroid Build Coastguard Worker mod = Mod2() 748*da0073e9SAndroid Build Coastguard Worker with FlopCounterMode() as mode: 749*da0073e9SAndroid Build Coastguard Worker mod(torch.randn(10, 10, requires_grad=True))[0].sum().backward() 750*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 751*da0073e9SAndroid Build Coastguard Worker (mode.flop_counts["Mod2"][torch.ops.aten.mm]), """6000""" 752*da0073e9SAndroid Build Coastguard Worker ) 753*da0073e9SAndroid Build Coastguard Worker 754*da0073e9SAndroid Build Coastguard Worker def test_warning(self): 755*da0073e9SAndroid Build Coastguard Worker mod = torch.nn.Linear(2, 2) 756*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex(UserWarning, "not needed"): 757*da0073e9SAndroid Build Coastguard Worker FlopCounterMode(mod) 758*da0073e9SAndroid Build Coastguard Worker 759*da0073e9SAndroid Build Coastguard Worker def test_custom_op(self): 760*da0073e9SAndroid Build Coastguard Worker from torch.utils.flop_counter import FlopCounterMode, register_flop_formula 761*da0073e9SAndroid Build Coastguard Worker 762*da0073e9SAndroid Build Coastguard Worker @torch.library.custom_op("mylib::foo", mutates_args=()) 763*da0073e9SAndroid Build Coastguard Worker def foo(x: torch.Tensor) -> torch.Tensor: 764*da0073e9SAndroid Build Coastguard Worker return x.sin() 765*da0073e9SAndroid Build Coastguard Worker 766*da0073e9SAndroid Build Coastguard Worker called = 0 767*da0073e9SAndroid Build Coastguard Worker 768*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "expected each target to be OpOverloadPacket"): 769*da0073e9SAndroid Build Coastguard Worker register_flop_formula(torch.ops.mylib.foo.default)(lambda x: x) 770*da0073e9SAndroid Build Coastguard Worker 771*da0073e9SAndroid Build Coastguard Worker @register_flop_formula(torch.ops.mylib.foo) 772*da0073e9SAndroid Build Coastguard Worker def formula(*args, 
**kwargs): 773*da0073e9SAndroid Build Coastguard Worker nonlocal called 774*da0073e9SAndroid Build Coastguard Worker called += 1 775*da0073e9SAndroid Build Coastguard Worker return 9001 776*da0073e9SAndroid Build Coastguard Worker 777*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 778*da0073e9SAndroid Build Coastguard Worker with FlopCounterMode(display=False) as mode: 779*da0073e9SAndroid Build Coastguard Worker y = foo(x) 780*da0073e9SAndroid Build Coastguard Worker 781*da0073e9SAndroid Build Coastguard Worker self.assertEqual(called, 1) 782*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(get_total_flops(mode), """9001""") 783*da0073e9SAndroid Build Coastguard Worker 784*da0073e9SAndroid Build Coastguard Worker 785*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 786*da0073e9SAndroid Build Coastguard Worker run_tests() 787