# Owner(s): ["module: functorch"]

import inspect
import random
import unittest
from typing import Callable

import torch
import torch.fx as fx
import torch.nn as nn
from functorch import make_fx
from functorch.compile import memory_efficient_fusion
from torch._functorch.compile_utils import fx_graph_cse
from torch.nn import functional as F
from torch.testing._internal.common_utils import run_tests, TestCase


HAS_CUDA = torch.cuda.is_available()


def _num_args(fn: Callable):
    return len(inspect.signature(fn).parameters)


def gelu_bias(bias, y):
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))


def swish(x):
    return x * torch.sigmoid(x)


def mish(x):
    return x.mul(torch.tanh(F.softplus(x)))


def hard_sigmoid(x):
    return (x + 3.0).clamp(min=0.0, max=6.0).div(6.0)


def hard_swish(x):
    return x * (x + 3.0).clamp(min=0.0, max=6.0).div(6.0)


def hard_mish(x):
    return 0.5 * x * (x + 2.0).clamp(min=0.0, max=2.0)


# todo: convert these into tests
# def group_std(x, groups: int = 32, eps: float = 1e-5, flatten: bool = False):
#     B, C, H, W = x.shape
#     x_dtype = x.dtype
#     if flatten:
#         x = x.reshape(B, groups, -1)  # FIXME simpler shape causing TPU / XLA issues
#         std = x.float().var(dim=2, unbiased=False, keepdim=True).add(eps).sqrt().to(x_dtype)
#     else:
#         x = x.reshape(B, groups, C // groups, H, W)
#         std = x.float().var(dim=(2, 3, 4), unbiased=False, keepdim=True).add(eps).sqrt().to(x_dtype)
#     return std.expand(x.shape).reshape(B, C, H, W)

# class EvoNorm2dS0(nn.Module):
#     def __init__(self, num_features, groups=32, group_size=None, apply_act=True, eps=1e-5, **_):
#         super().__init__()
#         self.apply_act = apply_act  # apply activation (non-linearity)
#         if group_size:
#             assert num_features % group_size == 0
#             self.groups = num_features // group_size
#         else:
#             self.groups = groups
#         self.eps = eps
#         self.weight = nn.Parameter(torch.ones(num_features))
#         self.bias = nn.Parameter(torch.zeros(num_features))
#         self.v = nn.Parameter(torch.ones(num_features)) if apply_act else None
#         self.reset_parameters()

#     def reset_parameters(self):
#         nn.init.ones_(self.weight)
#         nn.init.zeros_(self.bias)
#         if self.v is not None:
#             nn.init.ones_(self.v)

#     def forward(self, x):
#         x_dtype = x.dtype
#         v_shape = (1, -1, 1, 1)
#         if self.v is not None:
#             v = self.v.view(v_shape).to(dtype=x_dtype)
#             x = x * (x * v).sigmoid() / group_std(x, self.groups, self.eps)
#         return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)


# device = "cuda"
# dtype = torch.float

# evo_norm = EvoNorm2dS0(2048)
# evo_norm_inp = [(128, 2048, 8, 8)]


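# Runs `fn` once eagerly and several times through memory_efficient_fusion on
# cloned inputs (with the "fuser1" TorchScript fuser selected), then checks that
# the outputs and the input gradients match.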
def run_and_compare_activation(self, fn, inps):
    with torch.jit.fuser("fuser1"):
        device = "cuda"
        dtype = torch.float
        if isinstance(fn, nn.Module):
            fn = fn.to(device=device, dtype=dtype)

        ref_args = [
            torch.randn(shape, device=device, dtype=dtype, requires_grad=True)
            for shape in inps
        ]
        res_args = [i.clone().detach().requires_grad_(True) for i in ref_args]

        ref = fn(*ref_args)
        ref.sum().backward()

        mem_optimized_fn = memory_efficient_fusion(fn)
        for _ in range(5):
            for i in res_args:
                i.grad = None
            res = mem_optimized_fn(*res_args)
            res.sum().backward()

        self.assertEqual(ref, res)
        for ref_arg, res_arg in zip(ref_args, res_args):
            self.assertEqual(ref_arg.grad, res_arg.grad)


@unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
class TestMemoryEfficientOpAuthoring(TestCase):
    def test_gelu_bias(self):
        run_and_compare_activation(self, gelu_bias, [(1024,), (1024,)])

    def test_mish(self):
        run_and_compare_activation(self, mish, [(1024,)])

    def test_swish(self):
        run_and_compare_activation(self, swish, [(1024,)])

    def test_hard_sigmoid(self):
        run_and_compare_activation(self, hard_sigmoid, [(1024,)])

    def test_hard_swish(self):
        run_and_compare_activation(self, hard_swish, [(1024,)])

    def test_layer_norm(self):
        def layer_norm(x, weight, bias):
            dim = -1
            eps = 1e-5
            mean = torch.mean(x, dim, keepdim=True)
            centered = x - mean
            var = torch.sum(centered * centered, dim, keepdim=True) / x.size(-1)
            rvar = 1.0 / torch.sqrt(var + eps)
            normed = (x - mean) * rvar
            return normed * weight + bias

        bs = 10
        ln_size = 16
        layer_norm_inps = [(bs, ln_size), (ln_size,), (ln_size,)]
        run_and_compare_activation(self, layer_norm, layer_norm_inps)

    def test_rmsnorm(self):
        class T5LayerNorm(nn.Module):
            def __init__(self, hidden_size, eps=1e-6):
                """
                Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
                """
                super().__init__()
                self.weight = nn.Parameter(torch.ones(hidden_size))
                self.variance_epsilon = eps

            def forward(self, hidden_states):
                # layer norm should always be calculated in float32
                variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
                hidden_states = hidden_states * torch.rsqrt(
                    variance + self.variance_epsilon
                )

                # convert into half-precision if necessary
                if self.weight.dtype in [torch.float16, torch.bfloat16]:
                    hidden_states = hidden_states.to(self.weight.dtype)

                return self.weight * hidden_states

        bs = 256
        seq = 256
        hidden = 1024
        t5_norm = T5LayerNorm(hidden)
        t5_norm_inputs = [(bs, seq, hidden)]
        run_and_compare_activation(self, t5_norm, t5_norm_inputs)

    # TODO - Assertion failure
    # def test_hard_mish(self):
    #   for compiler in compilers:
    #     run_and_compare_activation(hard_mish, 1024)


# Check that the CSE-modified graph of f has exactly `delta` fewer nodes and that
# a second CSE pass does not reduce the number of nodes any further.
# delta is an integer >= -1. If delta == -1, only check that the new graph
# has no more nodes than the original.
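# For example, check(lambda x: x.cos() + x.cos(), torch.randn(2, 2), 1) expects
# CSE to remove exactly one duplicated `cos` node.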
def check(f, t, delta, check_val=True, graph_input=False):
    if graph_input:
        fx_g = f
    else:
        fx_g = make_fx(f)(t)
    new_graph = fx_graph_cse(fx_g.graph)
    new_g = fx.GraphModule(fx_g, new_graph)

    # the number of nodes should decrease or stay the same
    old_num_nodes = len(fx_g.graph.nodes)
    new_num_nodes = len(new_graph.nodes)
    if delta == -1:
        assert (
            old_num_nodes >= new_num_nodes
        ), f"number of nodes increased {old_num_nodes}, {new_num_nodes}"
    else:
        assert (
            old_num_nodes == new_num_nodes + delta
        ), f"unexpected number of nodes: expected {old_num_nodes - delta}, got {new_num_nodes}\n {fx_g.graph} \n {new_graph}"

    # a second pass should not remove any more nodes
    pass_2_graph = fx_graph_cse(new_graph)
    pass_2_num_nodes = len(pass_2_graph.nodes)
    assert (
        pass_2_num_nodes == new_num_nodes
    ), f"second pass graph has fewer nodes {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}"

    # check correctness
    if check_val:
        true_result = fx_g(t)
        our_result = new_g(t)
        if true_result is None:  # both return None
            assert (
                our_result is None
            ), f"true result is None, CSE result is {our_result}"
        else:  # results returned are the same
            assert torch.all(
                true_result == our_result
            ), f"results are different {true_result}, {our_result}"  # check results are the same


class NoChangeTestCase(TestCase):
    def test_nochange(self):
        def f(x):
            a = x + 1
            b = x + a
            a = x
            d = x + a
            return b + d

        t = torch.randn(2, 2)
        check(f, t, 0)

    def test_empty(self):
        def f(x):
            pass

        t = torch.randn(2, 2)
        check(f, t, 0)

    def test_rand_like(self):
        def f(x):
            a = torch.rand_like(x)
            b = torch.rand_like(x)
            return a + b

        t = torch.randn(2, 2)
        check(f, t, 0, check_val=False)

    def test_rand_n(self):
        def f(x):
            a = torch.randn(4)
            b = torch.randn(4)
            return a + b

        t = torch.randn(2, 2)
        check(f, t, 0, check_val=False)

    def test_hash_with_numbers(self):
        # Test to repro issue with fx_graph_cse when
        # hash((primals_2, 1.0)) == hash((primals_2, 1))
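        # Python guarantees hash(1) == hash(1.0) because 1 == 1.0, so CSE must
        # not merge nodes whose args differ only in numeric type (int vs. float).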

        if torch._dynamo.is_compiling():
            self.skipTest("Unsupported if test run is compiled")

        def f(inpt, osize):
            size = inpt.shape[-1]
            s1 = size - 1
            s2 = size - 1.0
            scale = s2 / (osize - 1.0)
            inpt = torch.clamp(inpt, 0, s1)
            return scale * inpt

        # Fetch dynamic graph
        gms = []

        def toy_backend(gm, _):
            gms.append(gm)
            return gm.forward

        torch._dynamo.reset()
        fn = torch.compile(backend=toy_backend, dynamic=True)(f)

        t = torch.rand(3, 100)
        _ = fn(t, 50)
        assert len(gms) == 1, gms
        fx_g = gms[0]
        check(fx_g, None, 0, check_val=False, graph_input=True)


class ReduceTestCase(TestCase):
    def test_immutable_list_type(self):
        def f(x):
            a = x.sum(dim=1)
            b = x.sum(dim=1)
            c = x.sum()
            d = x.sum()
            return a + b + c + d

        t = torch.randn(2, 2)
        check(f, t, 2)

    def test_immutable_list_multiple_entries(self):
        def f(x):
            a = x.sum(dim=[0, 1])
            b = x.sum(dim=[0, 1])
            c = x.sum(dim=1)
            d = x.sum(dim=1)
            return a + b + c + d

        t = torch.randn(2, 2)
        check(f, t, 2)

    def test_simple(self):
        def f(x):
            a = x.cos()
            b = x.cos()
            c = a + a
            d = b + b
            return c + d

        t = torch.randn(2, 2)
        check(f, t, 2)

    def test_simple_2(self):
        def f(x):
            a = x.cos().sin()
            b = x.cos().sin()
            c = a + a
            d = b + b
            return c + d

        t = torch.randn(1)
        check(f, t, 3)

    def test_two_args_default(self):
        def f(x):
            a = x.sum(dim=1)
            b = x.sum(dim=1, keepdim=False)
            c = x.sum(dim=1, keepdim=False)
            d = x.sum(dim=1)
            return a + b + c + d

        t = torch.randn(2, 2)
        check(f, t, 3)

    def test_two_args(self):
        def f(x):
            a = x.sum(dim=1)
            b = x.sum(dim=1, keepdim=True)
            c = x.sum(dim=1, keepdim=True)
            d = x.sum(dim=1)
            return a + b + c + d

        t = torch.randn(2, 2)
        check(f, t, 2)

    def test_simple_multiple_same_ops(self):
        def f(x):
            a = x.sum()
            b = x.sum()
            c = x.sum()
            d = x.sum()
            return a + b + c + d

        t = torch.randn(2, 2)
        check(f, t, 3)

    def test_nested_immutable_list_type(self):
        def f(x):
            a = torch.cat((x, x))
            b = torch.cat((x, x))
            return a + b

        t = torch.randn(2, 2)
        check(f, t, 1)

    def test_kwarg(self):
        def f(x):
            a = torch.ones_like(x)
            b = torch.ones_like(x)
            return a + b

        t = torch.randn(2, 2)
        check(f, t, 1)


class RandomOpTestCase(TestCase):
    def test_random(self):
        def f(x):
            vals = [x]
            ops = [torch.clone, torch.cos, torch.tanh, torch.nn.functional.gelu]
            for _ in range(100):
                new_val = random.choice(ops)(random.choice(vals))
                vals.append(new_val)
            return vals[-1]

        fx_g = fx.symbolic_trace(f)
        fx_g.graph.eliminate_dead_code()
        fx_g.recompile()
        t = torch.randn(2, 2)

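        # delta=-1: the randomly built graph may or may not contain duplicated
        # subexpressions, so only require that CSE never increases the node count.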
        for _ in range(30):
            check(fx_g, t, -1, graph_input=True)


if __name__ == "__main__":
    run_tests()