# Owner(s): ["oncall: pt2"]
import random
import unittest
from math import prod

import torch
import torch._functorch.config as config
from torch.testing._internal.common_utils import run_tests, TEST_WITH_ROCM, TestCase
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.utils._triton import has_triton
from torch.utils.flop_counter import FlopCounterMode, register_flop_formula


if has_triton():
    # note: if we only import triton in the test, the test fails:
    # def relu_kernel_(inp_ptr, out_ptr, sz, BLOCK_SIZE: tl.constexpr):
    # NameError('tl is not defined')
    import triton
    import triton.language as tl


def compile_with_ac(f, memory_budget):
    return torch.compile(f, backend="aot_eager_decomp_partition")


def get_act_mem(f):
    out = f()
    out.backward()
    start_mem = torch.cuda.memory_stats()["requested_bytes.all.current"]
    out = f()
    cur_mem = torch.cuda.memory_stats()["requested_bytes.all.current"]
    act_mem = (cur_mem - start_mem) / (1024 * 1024)
    out.backward()
    return act_mem


def get_bw_flops(f):
    # Normalized so that a 512 square matmul returns 1
    f().backward()
    out = f()
    with FlopCounterMode(display=False) as mode:
        out.backward()
    return mode.get_total_flops() / (512**3 * 2)


def create_pair(B_I, O):
    # results in B_I * O memory, requires B_I * B_I * O flops
    # arithmetic intensity of B_I
    x = torch.randn(B_I * 512, B_I * 512, requires_grad=True)
    w = torch.randn(B_I * 512, O * 512, requires_grad=True)
    return x, w


def get_mem_and_flops(f, memory_budget=None):
    # Returns megabytes rounded to 1 decimal point and FLOPs
    # Note that each value of size (512, 512, torch.float32) is 1 MiB
    torch._dynamo.reset()
    with config.patch(activation_memory_budget=memory_budget):
        if memory_budget is not None:
            f = torch.compile(f, backend="aot_eager_decomp_partition")

        # We round this to nearest 10th of a megabyte.
        return round(get_act_mem(f), 1), get_bw_flops(f)


class MemoryBudgetTest(TestCase):
    def setUp(self):
        super().setUp()
        torch.set_default_device("cuda")

    def test_rematerializes_cheap(self):
        def f(x, w):
            x = x.cos()
            x = torch.mm(x, w)
            return x.sum()

        x = torch.randn(512, 512, requires_grad=True)
        w = torch.randn(512, 512, requires_grad=True)

        def call():
            return f(x, w)

        eager_mem, eager_flops = get_mem_and_flops(call)
        self.assertEqual(eager_mem, 1.0)
        mem_10, flops_10 = get_mem_and_flops(call, memory_budget=1.0)
        # Recomputing `.cos()` is not free here.
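        # With the full budget of 1.0 the single 1 MiB activation fits, so it
        # is kept and the backward flops match eager.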
        self.assertEqual(mem_10, 1.0)
        self.assertEqual(eager_flops, flops_10)
        mem_5, flops_5 = get_mem_and_flops(call, memory_budget=0.5)
        # We can just recompute `x.cos()` here to only depend on the inputs
        self.assertEqual(mem_5, 0.0)
        self.assertEqual(flops_5, eager_flops)

    def test_matmul_even_chain(self):
        def f(x, ws):
            x = x.cos()
            for w in ws:
                x = torch.mm(x, w).cos()
            return x.sum()

        x = torch.randn(512, 512, requires_grad=True)
        ws = [torch.randn(512, 512, requires_grad=True) for _ in range(5)]

        def call():
            return f(x, ws)

        eager_mem, eager_flops = get_mem_and_flops(call)
        for budget in range(0, 11):
            mem, flops = get_mem_and_flops(call, memory_budget=budget / 10)
            if budget <= 5:
                # We start saving the matmuls
                self.assertEqual(mem, budget)
                self.assertEqual(flops, eager_flops + (5 - budget))
            elif budget < 10:
                # We're only recomputing the `cos` operations
                self.assertEqual(mem, 5.0)
                self.assertEqual(flops, eager_flops)
            elif budget == 10:
                self.assertEqual(mem, 10.0)
                self.assertEqual(flops, eager_flops)

    def test_matmul_uneven_chain(self):
        # This function is constructed so that for each w we save one
        # activation of shape [512, dim * 512] (the matmul output that feeds
        # `cos`). In addition, every matmul has the same ratio of compute to
        # "memory saved", so this test is essentially testing our knapsack
        # solver.

        def f(x, ws):
            xs = [torch.mm(x, w).cos() for w in ws]
            return sum(x.sum() for x in xs)

        x = torch.randn(512, 512, requires_grad=True)

        def make_weights(w_shapes):
            ws = []
            for idx, dim in enumerate(w_shapes):
                ws.append(torch.randn(512, dim * 512, requires_grad=True))
            return ws

        def make_weights_chain(w_shapes):
            ws = []
            for idx, _ in enumerate(w_shapes):
                old_dim = 512 if idx == 0 else w_shapes[idx - 1] * 512
                new_dim = w_shapes[idx] * 512
                ws.append(torch.randn(old_dim, new_dim, requires_grad=True))
            return ws

        weight_configs = [
            (
                [11, 3, 4, 2],
                [
                    18,  # 11 + 4 + 3
                    17,  # 11 + 4 + 2
                    16,  # 11 + 3 + 2
                    15,  # 11 + 4
                    14,  # 11 + 3
                    13,  # 11 + 2
                    11,  # 11
                    7,  # 4 + 3
                    6,  # 4 + 2
                    5,  # 3 + 2
                ],
            ),
            (
                [3, 5, 11, 17, 14],
                [
                    42,  # 17 + 14 + 11
                    30,  # 14 + 11 + 5
                    19,  # 11 + 5 + 3
                    8,  # 5 + 3
                    3,  # 3
                ],
            ),
        ]
        random.seed(0)
        random_arr = [random.randint(0, 50) for _ in range(10)]
        exact_sums = []
        for i in range(10):
            random.shuffle(random_arr)
            exact_sums.append(sum(random_arr[:i]))
        weight_configs.append((random_arr, exact_sums))

        for weight_shapes, exact_solves in weight_configs:
            ws = make_weights(weight_shapes)

            def call():
                return f(x, ws)

            eager_mem, eager_flops = get_mem_and_flops(call)
            total_mem = sum(weight_shapes)
            self.assertEqual(eager_mem, sum(weight_shapes))
            for mem_achieved in exact_solves:
                mem, _ = get_mem_and_flops(
                    call, memory_budget=mem_achieved / total_mem
                )
                self.assertEqual(mem, mem_achieved)

    # Needs CUDA, but then again, this whole test file needs CUDA.
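    # This test registers a custom Triton relu as a torch library op together
    # with a flop formula (via register_flop_formula) so the memory-budget
    # partitioning can cost it, then checks that the compiled output matches
    # eager for every budget from 0.0 to 1.0.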
    @unittest.skipIf(not has_triton(), "test needs triton")
    def test_custom_triton_kernel(self):
        @triton.jit
        def relu_kernel_(inp_ptr, out_ptr, sz, BLOCK_SIZE: tl.constexpr):
            pid = tl.program_id(0)
            block = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
            msk = block < sz
            inp = tl.load(inp_ptr + block, mask=msk)
            relu = tl.where(inp < 0, 0, inp)
            tl.store(out_ptr + block, relu, mask=msk)

        @torch._library.triton_op("testac::triton_relu", mutates_args=())
        def triton_relu(x: torch.Tensor) -> torch.Tensor:
            y = torch.empty_like(x)
            sz = y.numel()
            BLOCK_SIZE = 256
            grid = (triton.cdiv(sz, BLOCK_SIZE),)
            torch._library.capture_triton(relu_kernel_)[grid](x, y, sz, BLOCK_SIZE)
            return y

        @torch._library.triton_op("testac::triton_relu_backward", mutates_args=())
        def triton_relu_backward(grad_out: torch.Tensor) -> torch.Tensor:
            grad_x = torch.empty_like(grad_out)
            sz = grad_out.numel()
            BLOCK_SIZE = 256
            grid = (triton.cdiv(sz, BLOCK_SIZE),)
            # NB: this is not the real relu backward (it just applies relu to
            # grad_out again), but the test only compares forward outputs, so
            # a stand-in kernel is fine here.
            torch._library.capture_triton(relu_kernel_)[grid](
                grad_out, grad_x, sz, BLOCK_SIZE
            )
            return grad_x

        def _triton_relu_backward(ctx, grad_out: torch.Tensor) -> torch.Tensor:
            return triton_relu_backward(grad_out)

        def _triton_relu_setup_context(ctx, inputs, output):
            pass

        triton_relu.register_autograd(
            _triton_relu_backward,
            setup_context=_triton_relu_setup_context,
        )

        @register_flop_formula(
            [torch.ops.testac.triton_relu, torch.ops.testac.triton_relu_backward]
        )
        def triton_relu_flops(inp_shape, *args, **kwargs):
            return prod(inp_shape)

        def f(x, ws):
            x = torch.ops.testac.triton_relu(x)
            for w in ws:
                x = torch.ops.testac.triton_relu(torch.mm(x, w))
            return x.sum()

        x = torch.randn(512, 512, requires_grad=True, device="cuda")
        ws = [
            torch.randn(512, 512, requires_grad=True, device="cuda") for _ in range(5)
        ]

        def call():
            return f(x, ws)

        expected = call()
        for budget in range(0, 11):
            memory_budget = budget / 10
            torch._dynamo.reset()
            with config.patch(activation_memory_budget=memory_budget):
                f_compile = torch.compile(
                    call, backend="aot_eager_decomp_partition"
                )
                self.assertEqual(expected, f_compile())

    def test_prioritize_cheaper_matmul(self):
        def f(xs, ws):
            xs = [torch.mm(x, w).cos() for x, w in zip(xs, ws)]
            return sum(x.sum() for x in xs)

        x1, w1 = create_pair(1, 4)
        x2, w2 = create_pair(2, 2)

        def call():
            return f([x1, x2], [w1, w2])

        eager_mem, eager_flops = get_mem_and_flops(call)
        self.assertEqual(eager_mem, 8)
        self.assertEqual(eager_flops, 24)
        comp_mem, comp_flops = get_mem_and_flops(call, memory_budget=0.5)
        self.assertEqual(comp_mem, 4)
        # We are recomputing x1 @ w1 here!
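        # create_pair(1, 4) makes a (512, 512) @ (512, 2048) matmul, i.e.
        # 1 * 1 * 4 = 4 normalized flops, hence the +4 below.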
        self.assertEqual(comp_flops, eager_flops + 4)

    @config.patch(activation_memory_budget_runtime_estimator="profile")
    def test_profile(self):
        def f(x, ws):
            x = x.cos()
            for w in ws:
                x = torch.mm(x, w).cos()
            return x.sum()

        x = torch.randn(512, 512, requires_grad=True)
        ws = [torch.randn(512, 512, requires_grad=True) for _ in range(5)]

        def call():
            return f(x, ws)

        eager_mem, eager_flops = get_mem_and_flops(call)
        mem, flops = get_mem_and_flops(call, memory_budget=0.2)
        # We start saving the matmuls: 2 of the 5 are kept, the other 3 are
        # recomputed.
        self.assertEqual(mem, 2)
        self.assertEqual(flops, eager_flops + 3)

    def test_prioritize_cheaper_matmul2(self):
        def f(xs, ws):
            xs = [torch.mm(x, w).cos() for x, w in zip(xs, ws)]
            return sum(x.sum() for x in xs)

        data = [(4, 4), (6, 2), (2, 6)]
        xs, ws = zip(*[create_pair(a, b) for a, b in data])

        def call():
            return f(xs, ws)

        eager_mem, eager_flops = get_mem_and_flops(call)
        self.assertEqual(eager_mem, 40)
        self.assertEqual(eager_flops, 320)
        mem, flops = get_mem_and_flops(call, memory_budget=28 / eager_mem)
        # Save w1 and w2
        self.assertEqual(mem, 28)
        # We're recomputing w3 (the cheap one!)
        self.assertEqual(flops - eager_flops, 2 * 2 * 6)
        mem, flops = get_mem_and_flops(call, memory_budget=16 / eager_mem)
        # Save w2. Note that even though saving w1 gets us closer to our memory
        # limit, w2 is actually *more* FLOPs than w1!
        self.assertEqual(mem, 12)
        self.assertEqual(flops - eager_flops, 2 * 2 * 6 + 4 * 4 * 4)

    def test_attention_vs_linear(self):
        def f(x, w):
            orig_shape = x.shape
            x = x.reshape(1, 1, x.shape[0], x.shape[1])
            # Not a real attention layer (q = k = v = x), but it has the right
            # flop and memory profile for this test.
            x = torch.nn.functional.scaled_dot_product_attention(
                x, x, x, is_causal=False
            ).reshape(*orig_shape)
            x = torch.mm(x, w)
            x = x.cos()
            return x.sum()

        def try_seq_length(S, D, expected_recompute):
            x = torch.randn(S * 512, D * 512, requires_grad=True)
            w = torch.randn(D * 512, D * 512, requires_grad=True)

            def call():
                return f(x, w)

            with FlopCounterMode(display=False) as mode:
                call()
            mm_flops = mode.get_flop_counts()["Global"][torch.ops.aten.mm]
            attn_flops = mode.get_total_flops() - mm_flops
            mm_flops /= 512**3 * 2
            attn_flops /= 512**3 * 2

            eager_mem, eager_flops = get_mem_and_flops(call)
            self.assertEqual(eager_mem, S * D * 2)

            # Force it to recompute one of mm or attn
            mem, flops = get_mem_and_flops(call, memory_budget=0.6)
            self.assertEqual(mem, S * D)
            if expected_recompute == "attn":
                expected_flops = attn_flops
            else:
                expected_flops = mm_flops
            self.assertEqual(flops - eager_flops, expected_flops)

        # The general idea behind this test is that if sequence length * 2 > D,
        # then attention is more expensive than the linear.
        try_seq_length(1, 1, "mm")
        try_seq_length(1, 3, "attn")
        try_seq_length(2, 2, "mm")
        try_seq_length(2, 1, "mm")
        try_seq_length(2, 5, "attn")
        try_seq_length(4, 7, "mm")
        try_seq_length(4, 9, "attn")


if __name__ == "__main__":
    # We use the cuda memory allocator stats to verify memory allocations
    if HAS_CUDA and not TEST_WITH_ROCM:
        run_tests()