# Owner(s): ["oncall: pt2"]
import random
import unittest
from math import prod

import torch
import torch._functorch.config as config
from torch.testing._internal.common_utils import run_tests, TEST_WITH_ROCM, TestCase
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.utils._triton import has_triton
from torch.utils.flop_counter import FlopCounterMode, register_flop_formula


if has_triton():
    # Note: if we only import triton inside the test, defining relu_kernel_
    # (which references tl) fails with NameError("tl is not defined"),
    # so import triton at module scope instead.
    import triton
    import triton.language as tl


def compile_with_ac(f, memory_budget):
    return torch.compile(f, backend="aot_eager_decomp_partition")


def get_act_mem(f):
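    # Warm up with one forward+backward (this also materializes .grad buffers),
    # snapshot the allocator's requested bytes, then run the forward again: the
    # growth is the activation memory held for backward, reported in MiB.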
    out = f()
    out.backward()
    start_mem = torch.cuda.memory_stats()["requested_bytes.all.current"]
    out = f()
    cur_mem = torch.cuda.memory_stats()["requested_bytes.all.current"]
    act_mem = (cur_mem - start_mem) / (1024 * 1024)
    out.backward()
    return act_mem


def get_bw_flops(f):
    # Normalized so that a 512x512 square matmul counts as 1
    f().backward()
    out = f()
    with FlopCounterMode(display=False) as mode:
        out.backward()
    return mode.get_total_flops() / (512**3 * 2)


def create_pair(B_I, O):
    # The resulting matmul output takes B_I * O MiB of memory and costs
    # B_I * B_I * O normalized FLOPs, i.e. an arithmetic intensity of B_I.
    x = torch.randn(B_I * 512, B_I * 512, requires_grad=True)
    w = torch.randn(B_I * 512, O * 512, requires_grad=True)
    return x, w


def get_mem_and_flops(f, memory_budget=None):
    # Returns activation memory in MiB (rounded to 1 decimal place) and
    # normalized backward FLOPs. Each (512, 512) float32 tensor is 1 MiB.
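    # memory_budget=None runs f eagerly as the baseline; otherwise f is compiled
    # with the partitioner backend under the given activation memory budget.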
    torch._dynamo.reset()
    with config.patch(activation_memory_budget=memory_budget):
        if memory_budget is not None:
            f = torch.compile(f, backend="aot_eager_decomp_partition")

        # We round this to the nearest 10th of a megabyte.
        return round(get_act_mem(f), 1), get_bw_flops(f)


class MemoryBudgetTest(TestCase):
    def setUp(self):
        super().setUp()
        torch.set_default_device("cuda")

    def test_rematerializes_cheap(self):
        def f(x, w):
            x = x.cos()
            x = torch.mm(x, w)
            return x.sum()

        x = torch.randn(512, 512, requires_grad=True)
        w = torch.randn(512, 512, requires_grad=True)

        def call():
            return f(x, w)

        eager_mem, eager_flops = get_mem_and_flops(call)
        self.assertEqual(eager_mem, 1.0)
        mem_10, flops_10 = get_mem_and_flops(call, memory_budget=1.0)
        # Recomputing `.cos()` is not free here.
        self.assertEqual(mem_10, 1.0)
        self.assertEqual(eager_flops, flops_10)
        mem_5, flops_5 = get_mem_and_flops(call, memory_budget=0.5)
        # We can just recompute `x.cos()` here to only depend on the inputs
        self.assertEqual(mem_5, 0.0)
        self.assertEqual(flops_5, eager_flops)

    def test_matmul_even_chain(self):
        def f(x, ws):
            x = x.cos()
            for w in ws:
                x = torch.mm(x, w).cos()
            return x.sum()

        x = torch.randn(512, 512, requires_grad=True)
        ws = [torch.randn(512, 512, requires_grad=True) for _ in range(5)]

        def call():
            return f(x, ws)

        eager_mem, eager_flops = get_mem_and_flops(call)
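        # With a full budget, the 5 matmul outputs and 5 cos outputs (1 MiB
        # each) are all saved. Recomputing a matmul output costs 1 normalized
        # FLOP; recomputing a cos is free in this FLOP accounting, since
        # pointwise ops aren't counted.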
        for budget in range(0, 11):
            mem, flops = get_mem_and_flops(call, memory_budget=budget / 10)
            if budget <= 5:
                # We start saving the matmuls
                self.assertEqual(mem, budget)
                self.assertEqual(flops, eager_flops + (5 - budget))
            elif budget < 10:
                # We're only recomputing the `cos` operations
                self.assertEqual(mem, 5.0)
                self.assertEqual(flops, eager_flops)
            elif budget == 10:
                self.assertEqual(mem, 10.0)
                self.assertEqual(flops, eager_flops)

    def test_matmul_uneven_chain(self):
        # This function is constructed so that we save one activation of shape
        # [512, dim * 512] (i.e. `dim` MiB) for each w.
        # In addition, every matmul has the same ratio of compute to "memory
        # saved", so this test is essentially testing our knapsack solving.

        def f(x, ws):
            xs = [torch.mm(x, w).cos() for w in ws]
            return sum(x.sum() for x in xs)

        x = torch.randn(512, 512, requires_grad=True)

        def make_weights(w_shapes):
            ws = []
            for idx, dim in enumerate(w_shapes):
                ws.append(torch.randn(512, dim * 512, requires_grad=True))
            return ws

        def make_weights_chain(w_shapes):
            ws = []
            for idx, _ in enumerate(w_shapes):
                old_dim = 512 if idx == 0 else w_shapes[idx - 1] * 512
                new_dim = w_shapes[idx] * 512
                ws.append(torch.randn(old_dim, new_dim, requires_grad=True))
            return ws

        weight_configs = [
            (
                [11, 3, 4, 2],
                [
                    18,  # 11 + 4 + 3
                    17,  # 11 + 4 + 2
                    16,  # 11 + 3 + 2
                    15,  # 11 + 4
                    14,  # 11 + 3
                    13,  # 11 + 2
                    11,  # 11
                    7,  # 4 + 3
                    6,  # 4 + 2
                    5,  # 3 + 2
                ],
            ),
            (
                [3, 5, 11, 17, 14],
                [
                    42,  # 17 + 14 + 11
                    30,  # 14 + 11 + 5
                    19,  # 11 + 5 + 3
                    8,  # 5 + 3
                    3,  # 3
                ],
            ),
        ]
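        # Also check randomly generated instances: prefix sums of the shuffled
        # array are exactly achievable subset sums, so the knapsack solver
        # should be able to hit each of them exactly.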
        random.seed(0)
        random_arr = [random.randint(0, 50) for _ in range(10)]
        exact_sums = []
        for i in range(10):
            random.shuffle(random_arr)
            exact_sums.append(sum(random_arr[:i]))
        weight_configs.append((random_arr, exact_sums))

        for weight_shapes, exact_solves in weight_configs:
            ws = make_weights(weight_shapes)

            def call():
                return f(x, ws)

            eager_mem, eager_flops = get_mem_and_flops(call)
            total_mem = sum(weight_shapes)
            self.assertEqual(eager_mem, sum(weight_shapes))
            for mem_achieved in exact_solves:
                mem, _ = get_mem_and_flops(call, memory_budget=mem_achieved / total_mem)
                self.assertEqual(mem, mem_achieved)

    # Needs CUDA too, but this whole test file already requires CUDA.
    @unittest.skipIf(not has_triton(), "test needs triton")
    def test_custom_triton_kernel(self):
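        # Checks that the memory-budget partitioner handles custom triton ops:
        # across all budgets, the compiled result should match the eager result.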
        @triton.jit
        def relu_kernel_(inp_ptr, out_ptr, sz, BLOCK_SIZE: tl.constexpr):
            pid = tl.program_id(0)
            block = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
            msk = block < sz
            inp = tl.load(inp_ptr + block, mask=msk)
            relu = tl.where(inp < 0, 0, inp)
            tl.store(out_ptr + block, relu, mask=msk)

        @torch._library.triton_op("testac::triton_relu", mutates_args=())
        def triton_relu(x: torch.Tensor) -> torch.Tensor:
            y = torch.empty_like(x)
            sz = y.numel()
            BLOCK_SIZE = 256
            grid = (triton.cdiv(sz, BLOCK_SIZE),)
            torch._library.capture_triton(relu_kernel_)[grid](x, y, sz, BLOCK_SIZE)
            return y

        @torch._library.triton_op("testac::triton_relu_backward", mutates_args=())
        def triton_relu_backward(grad_out: torch.Tensor) -> torch.Tensor:
            grad_x = torch.empty_like(grad_out)
            sz = grad_out.numel()
            BLOCK_SIZE = 256
            grid = (triton.cdiv(sz, BLOCK_SIZE),)
            # Not the real ReLU backward (it ignores the forward input), but
            # gradient correctness is not what this test is checking.
            torch._library.capture_triton(relu_kernel_)[grid](
                grad_out, grad_x, sz, BLOCK_SIZE
            )
            return grad_x

        def _triton_relu_backward(ctx, grad_out: torch.Tensor) -> torch.Tensor:
            return triton_relu_backward(grad_out)

        def _triton_relu_setup_context(ctx, inputs, output):
            pass

        triton_relu.register_autograd(
            _triton_relu_backward,
            setup_context=_triton_relu_setup_context,
        )

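        # Register a FLOP formula so the flop counter knows how to cost the
        # custom ops (the memory budget's flops-based runtime estimate uses it).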
        @register_flop_formula(
            [torch.ops.testac.triton_relu, torch.ops.testac.triton_relu_backward]
        )
        def triton_relu_flops(inp_shape, *args, **kwargs):
            return prod(inp_shape)

        def f(x, ws):
            x = torch.ops.testac.triton_relu(x)
            for w in ws:
                x = torch.ops.testac.triton_relu(torch.mm(x, w))
            return x.sum()

        x = torch.randn(512, 512, requires_grad=True, device="cuda")
        ws = [
            torch.randn(512, 512, requires_grad=True, device="cuda") for _ in range(5)
        ]

        def call():
            return f(x, ws)

        expected = call()
        for budget in range(0, 11):
            memory_budget = budget / 10
            torch._dynamo.reset()
            with config.patch(activation_memory_budget=memory_budget):
                f_compile = torch.compile(call, backend="aot_eager_decomp_partition")
                self.assertEqual(expected, f_compile())

    def test_prioritize_cheaper_matmul(self):
        def f(xs, ws):
            xs = [torch.mm(x, w).cos() for x, w in zip(xs, ws)]
            return sum(x.sum() for x in xs)

        x1, w1 = create_pair(1, 4)
        x2, w2 = create_pair(2, 2)

        def call():
            return f([x1, x2], [w1, w2])

        eager_mem, eager_flops = get_mem_and_flops(call)
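        # Per create_pair: mm1's output is 4 MiB and costs 4 normalized forward
        # FLOPs; mm2's output is 4 MiB and costs 8. Backward is 2x forward.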
        self.assertEqual(eager_mem, 8)
        self.assertEqual(eager_flops, 24)
        comp_mem, comp_flops = get_mem_and_flops(call, memory_budget=0.5)
        self.assertEqual(comp_mem, 4)
        # We are recomputing x1 @ w1 here!
        self.assertEqual(comp_flops, eager_flops + 4)

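    # Use profiling rather than FLOP counting to estimate runtimes when
    # deciding what to recompute.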
    @config.patch(activation_memory_budget_runtime_estimator="profile")
    def test_profile(self):
        def f(x, ws):
            x = x.cos()
            for w in ws:
                x = torch.mm(x, w).cos()
            return x.sum()

        x = torch.randn(512, 512, requires_grad=True)
        ws = [torch.randn(512, 512, requires_grad=True) for _ in range(5)]

        def call():
            return f(x, ws)

        eager_mem, eager_flops = get_mem_and_flops(call)
        mem, flops = get_mem_and_flops(call, memory_budget=0.2)
        # We start saving the matmuls
        self.assertEqual(mem, 2)
        self.assertEqual(flops, eager_flops + 3)

    def test_prioritize_cheaper_matmul2(self):
        def f(xs, ws):
            xs = [torch.mm(x, w).cos() for x, w in zip(xs, ws)]
            return sum(x.sum() for x in xs)

        data = [(4, 4), (6, 2), (2, 6)]
        xs, ws = zip(*[create_pair(a, b) for a, b in data])
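        # Per create_pair, the three matmul outputs take 16/12/12 MiB and cost
        # 64/72/24 normalized forward FLOPs respectively.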

        def call():
            return f(xs, ws)

        eager_mem, eager_flops = get_mem_and_flops(call)
        self.assertEqual(eager_mem, 40)
        self.assertEqual(eager_flops, 320)
        mem, flops = get_mem_and_flops(call, memory_budget=28 / eager_mem)
        # Save the outputs of mm1 and mm2
        self.assertEqual(mem, 28)
        # We're recomputing mm3 (the cheap one!)
        self.assertEqual(flops - eager_flops, 2 * 2 * 6)
        mem, flops = get_mem_and_flops(call, memory_budget=16 / eager_mem)
        # Save mm2's output. Note that even though saving mm1's output gets us
        # closer to our memory limit, mm2 costs *more* FLOPs to recompute!
        self.assertEqual(mem, 12)
        self.assertEqual(flops - eager_flops, 2 * 2 * 6 + 4 * 4 * 4)

    def test_attention_vs_linear(self):
        def f(x, w):
            orig_shape = x.shape
            x = x.reshape(1, 1, x.shape[0], x.shape[1])
            # Treating the whole matrix as a single attention head isn't
            # realistic, but the FLOP/memory accounting is all that matters here.
            x = torch.nn.functional.scaled_dot_product_attention(
                x, x, x, is_causal=False
            ).reshape(*orig_shape)
            x = torch.mm(x, w)
            x = x.cos()
            return x.sum()

        def try_seq_length(S, D, expected_recompute):
            x = torch.randn(S * 512, D * 512, requires_grad=True)
            w = torch.randn(D * 512, D * 512, requires_grad=True)

            def call():
                return f(x, w)

            with FlopCounterMode(display=False) as mode:
                call()
            mm_flops = mode.get_flop_counts()["Global"][torch.ops.aten.mm]
            attn_flops = mode.get_total_flops() - mm_flops
            mm_flops /= 512**3 * 2
            attn_flops /= 512**3 * 2

            eager_mem, eager_flops = get_mem_and_flops(call)
            self.assertEqual(eager_mem, S * D * 2)

            mem, flops = get_mem_and_flops(
                call, memory_budget=0.6
            )  # Force it to recompute one of mm or attn
            self.assertEqual(mem, S * D)
            if expected_recompute == "attn":
                expected_flops = attn_flops
            else:
                expected_flops = mm_flops
            self.assertEqual(flops - eager_flops, expected_flops)

        # The general idea behind this test: if 2 * sequence length > D, then
        # the attention is more expensive (in FLOPs) than the linear, so the
        # partitioner should recompute whichever of the two is cheaper.
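        # Roughly, in normalized matmul FLOPs: the attention costs ~2 * S^2 * D
        # (q @ k^T plus attn @ v) while the linear costs S * D^2, so attention
        # dominates exactly when 2 * S > D.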
        try_seq_length(1, 1, "mm")
        try_seq_length(1, 3, "attn")
        try_seq_length(2, 2, "mm")
        try_seq_length(2, 1, "mm")
        try_seq_length(2, 5, "attn")
        try_seq_length(4, 7, "mm")
        try_seq_length(4, 9, "attn")


if __name__ == "__main__":
    # I'm using the CUDA memory allocator's stats to verify memory allocations,
    # so these tests require CUDA (and are skipped on ROCm).
    if HAS_CUDA and not TEST_WITH_ROCM:
        run_tests()