# Owner(s): ["module: inductor"]
import functools
import os
import pickle
import unittest
from typing import List
from unittest import mock

import torch
from torch._dynamo import reset
from torch._dynamo.utils import counters
from torch._inductor import config, metrics
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.codecache import (
    cuda_compile_command,
    CUDACodeCache,
    FxGraphCachePickler,
    FxGraphHashDetails,
    PyCodeCache,
    TensorMetadata,
    TensorMetadataAndValues,
)
from torch._inductor.graph import GraphLowering
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import clear_inductor_caches, fresh_inductor_cache
from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_device_type import largeTensorTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
)
from torch.testing._internal.inductor_utils import (
    GPU_TYPE,
    HAS_CUDA,
    HAS_GPU,
    HAS_MULTIGPU,
    requires_gpu,
)
from torch.utils._triton import has_triton


try:
    from .mock_cache import global_stats, patch_fbcode, PatchCaches
except ImportError:
    from mock_cache import global_stats, patch_fbcode, PatchCaches  # @manual


HAS_TRITON = has_triton()

if HAS_TRITON:
    import triton  # @manual

    from torch.testing._internal.triton_utils import add_kernel

requires_triton = functools.partial(unittest.skipIf, not HAS_TRITON, "requires triton")

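# Enable the fake tensor cache (and its cross-checking) for all tests in this module.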
torch._dynamo.config.fake_tensor_cache_enabled = True
torch._dynamo.config.fake_tensor_cache_crosscheck_enabled = True


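# Minimal linear model used by the spawn/fork codecache smoke tests below.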
class MyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 10)

    def forward(self, inp):
        return self.fc1(inp)


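# Warm the async-compile worker pool with the given start method, then compile and
# run a small model forward/backward to smoke-test the code cache.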
def _run_codecache_test(start_method):
    with torch._inductor.config.patch(
        worker_start_method=start_method, compile_threads=16
    ):
        AsyncCompile.warm_pool()

        model = MyModel().to(device=GPU_TYPE)
        model = torch.compile(model)
        inp = torch.rand(10, 10).to(device=GPU_TYPE)
        model(inp).sum().backward()


@requires_gpu()
def test_codecache_spawn():
    _run_codecache_test("spawn")


@requires_gpu()
def test_codecache_fork():
    _run_codecache_test("fork")


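# Conv model with an explicit graph break, so the forward pass compiles as two
# separate FX graphs (and therefore multiple cache entries).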
class MyModelConv2d(torch.nn.Module):
    def __init__(self, dim=512):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, dim, kernel_size=3, stride=2, bias=False)
        self.conv2 = torch.nn.Conv2d(dim, dim, kernel_size=3, stride=2, bias=False)

    def forward(self, x):
        x = self.conv1(x)
        torch._dynamo.graph_break()
        x = self.conv2(x)
        return x


@instantiate_parametrized_tests
class TestFxGraphCache(TestCase):
    device_type = GPU_TYPE

    def setUp(self):
        super().setUp()
        counters.clear()
        PatchCaches.setUp()

    def tearDown(self):
        super().tearDown()
        PatchCaches.tearDown()

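    # Reset Dynamo and clear Inductor's in-process caches so the next call
    # recompiles and exercises the FX graph cache instead of reusing live state.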
    def reset(self):
        torch._dynamo.reset()
        clear_inductor_caches()

    @requires_triton()
    @config.patch({"fx_graph_cache": True})
    @config.patch({"fx_graph_remote_cache": False})
    @parametrize("device", (GPU_TYPE, "cpu"))
    @parametrize("dtype", (torch.float32, torch.bfloat16))
    @parametrize("dynamic", (False, True))
    def test_cache_load_function(self, device, dtype, dynamic):
        """
        Verify that we can populate and load functions from the cache.
        """
        if device == GPU_TYPE and not HAS_GPU:
            raise unittest.SkipTest(f"requires {GPU_TYPE}")
        if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
            raise unittest.SkipTest("requires SM80 or later")

        def fn(x, y):
            return (x * 2, y @ y)

        a = torch.rand(25, dtype=dtype, device=device)
        b = torch.rand(5, 5, dtype=dtype, device=device)

        compiled_fn = torch.compile(fn, dynamic=dynamic)

        # A first call should miss in the cache.
        self.assertEqual(fn(a, b), compiled_fn(a, b))
        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
        self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 0)

        # A second call should hit. (First reset so in-memory guards
        # don't prevent compilation).
        for m in torch._inductor.codecache.PyCodeCache.cache.values():
            os.remove(m.__file__)
        self.reset()
        self.assertEqual(fn(a, b), compiled_fn(a, b))
        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
        self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1)

    @requires_triton()
    @config.patch({"fx_graph_remote_cache": True})
    @parametrize("device", (GPU_TYPE, "cpu"))
    @parametrize("dtype", (torch.float32, torch.bfloat16))
    @parametrize("dynamic", (False, True))
    def test_remote_cache_load_function(self, device, dtype, dynamic):
        from unittest.mock import patch

        if device == GPU_TYPE and not HAS_GPU:
            raise unittest.SkipTest(f"requires {GPU_TYPE}")
        if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
            raise unittest.SkipTest("requires SM80 or later")

        def fn(x, y):
            return (x * 2, y @ y)

        a = torch.rand(25, dtype=dtype, device=device)
        b = torch.rand(5, 5, dtype=dtype, device=device)

        with config.patch(
            {
                "fx_graph_remote_cache": True,
            }
        ), patch.dict(os.environ), PatchCaches():
            os.environ.pop("TRITON_CACHE_MANAGER", None)
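            # The first iteration should miss and populate the remote cache; the
            # remaining three iterations should hit it.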
            for _ in range(4):
                with fresh_inductor_cache():
                    compiled_fn = torch.compile(fn, dynamic=dynamic)
                    self.assertEqual(fn(a, b), compiled_fn(a, b))
                reset()

        global_stats.report()
        self.assertEqual(global_stats.fx_graph.num_get_hit, 3)
        self.assertEqual(global_stats.fx_graph.num_get_miss, 1)
        self.assertEqual(global_stats.fx_graph.num_put, 1)

    @requires_triton()
    @config.patch({"fx_graph_cache": True})
    @config.patch({"fx_graph_remote_cache": False})
    @parametrize("device", (GPU_TYPE, "cpu"))
    @parametrize("dtype", (torch.float32, torch.float64))
    @parametrize("dynamic", (False, True))
    def test_cache_load_model(self, device, dtype, dynamic):
        """
        Verify that we can populate and load models from the cache.
        """
        if device == GPU_TYPE and not HAS_GPU:
            raise unittest.SkipTest(f"requires {GPU_TYPE}")

        def fn(mod, x):
            mod.zero_grad()
            mod(x).sum().backward()
            return [p.grad for p in mod.parameters()]

        compiled_fn = torch.compile(fn, dynamic=dynamic)

        mod = MyModelConv2d().to(device=device, dtype=dtype)
        inp = torch.randn(2, 3, 16, 16, device=device, dtype=dtype)

        # The first call should see all cache misses.
        counters.clear()
        grads1 = compiled_fn(mod, inp)
        self.assertGreater(counters["inductor"]["fxgraph_cache_miss"], 0)
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

        # The second call should see all hits. (First reset so in-memory guards
        # don't prevent compilation).
        counters.clear()
        self.reset()
        grads2 = compiled_fn(mod, inp)
        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
        self.assertGreater(counters["inductor"]["fxgraph_cache_hit"], 0)

        # And the results should be the same.
        self.assertEqual(grads1, grads2)

    @largeTensorTest("64GB", device=GPU_TYPE)
    @config.patch({"fx_graph_cache": True})
    @config.patch({"fx_graph_remote_cache": False})
    @parametrize("device", (GPU_TYPE,))
    @parametrize("dtype", (torch.float16, torch.bfloat16))
    def test_cache_load_with_guards_int32_bounds(self, device, dtype):
        """
        Test caching the same graph, but under conditions that introduce guards
        for tensor sizes < int32.
        """
        if device == GPU_TYPE and not HAS_GPU:
            raise unittest.SkipTest(f"requires {GPU_TYPE}")
        if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
            raise unittest.SkipTest("requires CUDA SM80 or later")

        def fn(x, y):
            return (x + x, y + y)

        compiled_fn = torch.compile(fn, dynamic=True)

        # Iterate over different shapes, varying whether the total
        # size is below or above int32. For each combination, we expect
        # different guards around whether the symbolic sizes do or do
        # not exceed int32.
        shapes = (
            ((5, 6), (7, 8)),
            ((5, 6), (47000, 47001)),
            ((47000, 47001), (5, 6)),
        )
        for a_shape, b_shape in shapes:
            a = torch.rand(a_shape, device=device, dtype=dtype)
            b = torch.rand(b_shape, device=device, dtype=dtype)

            # AVOID a dynamo reset here. We expect guards to have been
            # added that will be violated with the new shape. We should
            # see a recompilation (along with a cache miss).
            counters.clear()
            res1 = compiled_fn(a, b)
            self.assertGreater(counters["inductor"]["fxgraph_cache_miss"], 0)
            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

            # A second call should hit. (Reset here to force compilation).
            counters.clear()
            self.reset()
            res2 = compiled_fn(a, b)
            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
            self.assertGreater(counters["inductor"]["fxgraph_cache_hit"], 0)

            self.assertEqual(res1, res2)

    @config.patch({"fx_graph_cache": True})
    @config.patch({"fx_graph_remote_cache": False})
    @parametrize("device", (GPU_TYPE, "cpu"))
    @parametrize("dtype", (torch.float32, torch.bfloat16))
    def test_cache_load_with_guards_static_bounds(self, device, dtype):
        """
        Test caching the same graph, but under conditions that introduce guards
        for static bounds.
        """
        if device == GPU_TYPE and not HAS_GPU:
            raise unittest.SkipTest(f"requires {GPU_TYPE}")
        if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
            raise unittest.SkipTest("requires SM80 or later")

        # See lowering; for all of the pooling operators, we always guard and
        # make the height/width static.
        def fn(x):
            return torch.nn.functional.adaptive_avg_pool2d(x, [5, 7])

        compiled_fn = torch.compile(fn, dynamic=True)

        # Iterate over different input shapes. Each new shape should cause
        # a cache miss.
        shapes = ((1, 64, 8, 9), (1, 64, 9, 10), (1, 64, 10, 11))
        for shape in shapes:
            x = torch.rand(shape, device=device, dtype=dtype)

            # AVOID a dynamo reset here. For each cache hit, we expect guards
            # to have been added that will be violated with each new shape.
            # We should see a recompilation (along with a cache miss).
            counters.clear()
            res1 = compiled_fn(x)
            self.assertGreater(counters["inductor"]["fxgraph_cache_miss"], 0)
            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

            # A second call should hit.
            counters.clear()
            self.reset()
            res2 = compiled_fn(x)
            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
            self.assertGreater(counters["inductor"]["fxgraph_cache_hit"], 0)

            self.assertEqual(res1, res2)

    @config.patch({"fx_graph_cache": True})
    @config.patch({"fx_graph_remote_cache": False})
    @parametrize("device", (GPU_TYPE, "cpu"))
    def test_constant_handling(self, device):
        """
        Test that different constants are recognized correctly.
        """
        if device == GPU_TYPE and not HAS_GPU:
            raise unittest.SkipTest(f"requires {GPU_TYPE}")

        def fn1(x):
            return x + torch.tensor(list(range(0, 12)), device=device)

        def fn2(x):
            return x + torch.tensor(list(range(1, 13)), device=device)

        a = torch.rand(12, device=device)

        compiled_fn1 = torch.compile(fn1)
        compiled_fn2 = torch.compile(fn2)

        # A call to fn1 should miss in the cache.
        self.assertEqual(fn1(a), compiled_fn1(a))
        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

        # A call to fn2 should also miss (the constant is different)
        self.assertEqual(fn2(a), compiled_fn2(a))
        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

    @requires_gpu()
    @requires_triton()
    @config.patch({"fx_graph_cache": True})
    @config.patch({"fx_graph_remote_cache": False})
    def test_higher_order_op_bypass(self):
        """
        Verify that we bypass the cache when the graph contains higher-order ops
        (here, a user-defined Triton kernel call).
        """

        def fn(x, y):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (  # noqa: E731
                triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
            )
            add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)
            return output

        compiled_fn = torch.compile(fn, fullgraph=True)

        x = torch.randn(4, device=GPU_TYPE)
        y = torch.randn(4, device=GPU_TYPE)
        compiled_fn(x, y)

        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
        self.assertGreater(counters["inductor"]["fxgraph_cache_bypass"], 0)

    @config.patch({"fx_graph_cache": True})
    @config.patch({"fx_graph_remote_cache": False})
    def test_generated_kernel_count(self):
        """
        Test that we bump the generated_kernel_count metric on a cache hit.
        """

        def fn(x, y):
            return (x * y + y,)

        a = torch.rand(5, 5)
        b = torch.rand(5, 5)

        compiled_fn = torch.compile(fn)

        metrics.reset()
        self.assertEqual(metrics.generated_kernel_count, 0)

        # Verify the "miss" case.
        self.assertEqual(fn(a, b), compiled_fn(a, b))
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
        self.assertEqual(metrics.generated_kernel_count, 1)

        # Verify the "hit" case
        self.reset()
        self.assertEqual(fn(a, b), compiled_fn(a, b))
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
        self.assertEqual(metrics.generated_kernel_count, 2)

    @config.patch({"fx_graph_cache": True})
    @config.patch({"fx_graph_remote_cache": False})
    def test_inductor_counters(self):
        """
        Test that we bump the inductor counters on a cache hit.
        """
        compile_to_fn = GraphLowering.compile_to_fn

        counter_name = "a_test_counter"
        counter_incr = 7

        def bump_counter(self):
            # Mock that bumps some arbitrary test counter by a set amount, then calls
            # the original GraphLowering.compile_to_fn.
            counters["inductor"][counter_name] += counter_incr
            return compile_to_fn(self)

        with mock.patch.object(GraphLowering, "compile_to_fn", bump_counter):

            def fn(a, b):
                return torch.mm(a, b)

            a = torch.rand(8, 32, device="cpu")
            b = torch.rand(32, 8, device="cpu")

            compiled_fn = torch.compile(fn)

            # Verify the "miss" case.
            counter_val = 2
            counters["inductor"][counter_name] = counter_val
            self.assertEqual(fn(a, b), compiled_fn(a, b))
            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
            self.assertEqual(
                counters["inductor"][counter_name], counter_val + counter_incr
            )

            # Verify the "hit" case.
            self.reset()
            counter_val = 5
            counters["inductor"][counter_name] = counter_val
            self.assertEqual(fn(a, b), compiled_fn(a, b))
            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
            self.assertEqual(
                counters["inductor"][counter_name], counter_val + counter_incr
            )

    @config.patch({"fx_graph_cache": True})
    @config.patch({"fx_graph_remote_cache": False})
    def test_cache_clear(self):
        """
        Test clearing the cache.
        """

        def fn(x, y):
            return (x * y,)

        a = torch.rand(5, 5)
        b = torch.rand(5, 5)

        compiled_fn = torch.compile(fn)

        # A first call should miss in the cache.
        self.assertEqual(fn(a, b), compiled_fn(a, b))
        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

        # A second call should hit.
        counters.clear()
        self.reset()
        self.assertEqual(fn(a, b), compiled_fn(a, b))
        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)

        # Clear the cache; now we should miss.
        counters.clear()
        self.reset()
        torch._inductor.codecache.FxGraphCache.clear()
        self.assertEqual(fn(a, b), compiled_fn(a, b))
        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

    @config.patch({"fx_graph_cache": True})
    @config.patch({"fx_graph_remote_cache": False})
    def test_cache_with_nt(self):
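        # Helper: build a jagged NestedTensor whose flattened values have r rows.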
        def gen_nt(r):
            values = torch.randn(r, 16)
            offsets = torch.tensor([0, 2, 3, 6, 13, r])
            return torch.nested.nested_tensor_from_jagged(values, offsets)

        def fn(nt):
            if nt.values().size(0) % 16 == 0:
                return nt.sin()
            return nt.cos()

        inp1 = gen_nt(19)
        inp2 = gen_nt(20)

        counters.clear()
        torch.compile(fn)(inp1)
        torch.compile(fn)(inp2)
        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

        self.reset()
        counters.clear()
        torch.compile(fn)(inp1)
        torch.compile(fn)(inp2)
        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)

    @config.patch({"fx_graph_cache": True})
    @config.patch({"fx_graph_remote_cache": False})
    def test_cache_with_symint_non_arg_guard(self):
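        # With dynamic=True, the scalar ref_id is traced symbolically; comparing it
        # against the constant self_id installs a guard on a SymInt that is not
        # itself a graph argument. An identical second call should still hit the cache.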
        def fn(x, ref_id):
            self_id = 22
            if self_id == ref_id:
                x = torch.mul(x, 1.0)
            else:
                x = torch.mul(x, 0)
            return x

        x = torch.ones(2)

        counters.clear()
        torch.compile(fn, fullgraph=True, dynamic=True)(x, 2)
        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

        self.reset()
        counters.clear()
        torch.compile(fn, fullgraph=True, dynamic=True)(x, 2)
        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)

    @config.patch({"fx_graph_cache": True})
    @config.patch({"fx_graph_remote_cache": False})
    def test_cache_guard(self):
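        # val feeds a data-dependent branch, so the two calls below compile different
        # graphs: both miss the cache and produce different results.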
        def f(x, val):
            if val > 5:
                return x.sin()
            else:
                return x.cos()

        x = torch.ones(2)
        a = torch.compile(f, dynamic=True)(x, 6)
        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

        self.reset()
        counters.clear()
        b = torch.compile(f, dynamic=True)(x, 4)
        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

        self.assertNotEqual(a, b)


class TestFxGraphCacheHashing(TestCase):
    def test_tensor_constants(self):
        """
        Test the hashing of tensor constants.
        """
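        # Tensor constants are expected to be pickled with their values, not just
        # their metadata, so graphs differing only in a constant hash differently.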
        data = FxGraphCachePickler.dumps(torch.tensor(list(range(9))))
        self.assertIsInstance(pickle.loads(data), TensorMetadataAndValues)

    def test_hash_fake_tensors(self):
        """
        Test hashing (pickling) FakeTensors with various characteristics.
        """
        with torch._subclasses.FakeTensorMode():
            # Verify that FakeTensors get pickled into a TensorMetadata:
            data = FxGraphCachePickler.dumps(torch.randn(1))
            self.assertIsInstance(pickle.loads(data), TensorMetadata)

            # Different shapes:
            self.assertEqual(
                FxGraphCachePickler.dumps(torch.randn(3)),
                FxGraphCachePickler.dumps(torch.randn(3)),
            )
            self.assertNotEqual(
                FxGraphCachePickler.dumps(torch.randn(3)),
                FxGraphCachePickler.dumps(torch.randn(4)),
            )
            self.assertNotEqual(
                FxGraphCachePickler.dumps(torch.randn(3)),
                FxGraphCachePickler.dumps(torch.randn(3, 3)),
            )

            self.assertEqual(
                FxGraphCachePickler.dumps(torch.randn(3, 3)),
                FxGraphCachePickler.dumps(torch.randn(3, 3)),
            )
            self.assertNotEqual(
                FxGraphCachePickler.dumps(torch.randn(3, 3)),
                FxGraphCachePickler.dumps(torch.randn(3, 4)),
            )
            self.assertNotEqual(
                FxGraphCachePickler.dumps(torch.randn(3, 3)),
                FxGraphCachePickler.dumps(torch.randn(4, 3)),
            )

            # Different strides:
            self.assertEqual(
                FxGraphCachePickler.dumps(torch.randn(3, 3)),
                FxGraphCachePickler.dumps(
                    torch.randn(3, 3).transpose(0, 1).transpose(0, 1)
                ),
            )
            self.assertNotEqual(
                FxGraphCachePickler.dumps(torch.randn(3, 3)),
                FxGraphCachePickler.dumps(torch.randn(3, 3).transpose(0, 1)),
            )

            # Different storage offsets:
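            # Note: the storage offset does not appear to factor into the hash, so
            # tensors differing only in their offset still hash the same (both
            # assertions below are equalities).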
            self.assertEqual(
                FxGraphCachePickler.dumps(torch.randn(3)[1:]),
                FxGraphCachePickler.dumps(torch.randn(3)[1:]),
            )
            self.assertEqual(
                FxGraphCachePickler.dumps(torch.randn(3)[1:]),
                FxGraphCachePickler.dumps(torch.randn(2)),
            )

            # Different dtypes:
            self.assertEqual(
                FxGraphCachePickler.dumps(torch.randn(3, dtype=torch.float32)),
                FxGraphCachePickler.dumps(torch.randn(3, dtype=torch.float32)),
            )
            self.assertNotEqual(
                FxGraphCachePickler.dumps(torch.randn(3, dtype=torch.float32)),
                FxGraphCachePickler.dumps(torch.randn(3, dtype=torch.float64)),
            )

            # Different 'requires_grad':
            self.assertEqual(
                FxGraphCachePickler.dumps(torch.randn(3, requires_grad=True)),
                FxGraphCachePickler.dumps(torch.randn(3, requires_grad=True)),
            )
            self.assertNotEqual(
                FxGraphCachePickler.dumps(torch.randn(3, requires_grad=True)),
                FxGraphCachePickler.dumps(torch.randn(3, requires_grad=False)),
            )

            # Different memory formats:
            self.assertNotEqual(
                FxGraphCachePickler.dumps(torch.randn(1, 2, 3, 4)),
                FxGraphCachePickler.dumps(
                    torch.randn(1, 2, 3, 4).to(memory_format=torch.channels_last)
                ),
            )

            # Different devices:
            self.assertEqual(
                FxGraphCachePickler.dumps(torch.randn(3, device="meta")),
                FxGraphCachePickler.dumps(torch.randn(3, device="meta")),
            )
            self.assertNotEqual(
                FxGraphCachePickler.dumps(torch.randn(3, device="meta")),
                FxGraphCachePickler.dumps(torch.randn(3, device="cpu")),
            )

            if HAS_MULTIGPU:
                self.assertEqual(
                    FxGraphCachePickler.dumps(torch.randn(3, device=f"{GPU_TYPE}:1")),
                    FxGraphCachePickler.dumps(torch.randn(3, device=f"{GPU_TYPE}:1")),
                )
                self.assertNotEqual(
                    FxGraphCachePickler.dumps(torch.randn(3, device=f"{GPU_TYPE}:0")),
                    FxGraphCachePickler.dumps(torch.randn(3, device=f"{GPU_TYPE}:1")),
                )

    def test_hash_kwargs(self):
        """
        Test the special handling of the kwargs when hashing, i.e.,
        ordering of the kwargs dict and any set arguments.
        """
        # Dict order of the kwargs should not affect hashes.
        details1 = FxGraphHashDetails(None, [], {"a": 0, "z": 1}, [])
        details2 = FxGraphHashDetails(None, [], {"z": 1, "a": 0}, [])
        self.assertEqual(
            FxGraphCachePickler.dumps(details1),
            FxGraphCachePickler.dumps(details2),
        )

        # Different kwarg values should affect hashes.
        details1 = FxGraphHashDetails(None, [], {"a": 0}, [])
        details2 = FxGraphHashDetails(None, [], {"a": 1}, [])
        self.assertNotEqual(
            FxGraphCachePickler.dumps(details1),
            FxGraphCachePickler.dumps(details2),
        )

        # Set order should not affect hashes. Sets are unordered, but
        # sorting and creating a new set seems to change the order.
        set1 = {"a", "b", "c", "d", "e", "f", "g"}
        set2 = set(sorted(set1))  # noqa: C414
        details1 = FxGraphHashDetails(None, [], {"a": set1}, [])
        details2 = FxGraphHashDetails(None, [], {"a": set2}, [])
        self.assertEqual(
            FxGraphCachePickler.dumps(details1),
            FxGraphCachePickler.dumps(details2),
        )

        # But different set contents should affect hashes.
        details1 = FxGraphHashDetails(None, [], {"a": {1, 2, 3}}, [])
        details2 = FxGraphHashDetails(None, [], {"a": {1, 2}}, [])
        self.assertNotEqual(
            FxGraphCachePickler.dumps(details1),
            FxGraphCachePickler.dumps(details2),
        )

    def test_hash_config_changes(self):
        """
        Test that different config settings affect hashes.
        """
        with config.patch({"max_autotune": False}):
            details1 = FxGraphHashDetails(None, [], {}, [])
            details2 = FxGraphHashDetails(None, [], {}, [])

        with config.patch({"max_autotune": True}):
            details3 = FxGraphHashDetails(None, [], {}, [])

        self.assertEqual(
            FxGraphCachePickler.dumps(details1),
            FxGraphCachePickler.dumps(details2),
        )
        self.assertNotEqual(
            FxGraphCachePickler.dumps(details1),
            FxGraphCachePickler.dumps(details3),
        )

    @unittest.skipIf(not HAS_CUDA, "Requires CUDA")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_cuda_compile_command(self):
        cmd_no_extra_args: str = cuda_compile_command(
            ["abc.cu", "def.cu"], "output", "so"
        )
        assert "nvcc " in cmd_no_extra_args, cmd_no_extra_args
        assert "abc.cu" in cmd_no_extra_args, cmd_no_extra_args
        assert "def.cu" in cmd_no_extra_args, cmd_no_extra_args
        assert "output" in cmd_no_extra_args, cmd_no_extra_args
        cmd_extra_args: str = cuda_compile_command(
            ["abc.cu", "def.cu"], "output", "so", ["-Wwhatever", "-nothing"]
        )
        assert "nvcc " in cmd_extra_args, cmd_extra_args
        assert " -Wwhatever" in cmd_extra_args, cmd_extra_args
        assert " -nothing" in cmd_extra_args, cmd_extra_args
        assert "abc.cu" in cmd_extra_args, cmd_extra_args
        assert "def.cu" in cmd_extra_args, cmd_extra_args
        assert "output " in cmd_extra_args, cmd_extra_args
        with mock.patch("subprocess.check_output") as check_output_mock:
            CUDACodeCache.compile("test123.cu", "so", ["-Wsomething"])
            check_output_mock.assert_called()
            cmd_parts: List[str] = check_output_mock.call_args[0][0]
            assert cmd_parts[0] == "nvcc", cmd_parts
            assert "-Wsomething" in cmd_parts, cmd_parts
            assert "-DNDEBUG" in cmd_parts, cmd_parts


@instantiate_parametrized_tests
class TestAutotuneCache(TestCase):
    device_type = GPU_TYPE

    def setUp(self):
        super().setUp()
        counters.clear()
        PatchCaches.setUp()

    def tearDown(self):
        super().tearDown()
        PatchCaches.tearDown()

    def reset(self):
        torch._dynamo.reset()
        clear_inductor_caches()

    @unittest.skipIf(not HAS_CUDA, "Requires CUDA")
    @unittest.skipIf(not SM80OrLater, "Requires SM80+")
    @config.patch({"fx_graph_cache": False})
    @config.patch({"fx_graph_remote_cache": False})
    @config.patch({"autotune_local_cache": False})
    @config.patch({"autotune_remote_cache": True})
    @config.patch({"max_autotune": True})
    @parametrize("fbcode", (False,) + (True,) * config.is_fbcode())
    def test_autotune_cache(self, fbcode: bool):
        class Model(torch.nn.Module):
            def forward(self, x, y, a, b):
                return x + y, a + b

        def f(x, y, a, b):
            return Model()(x, y, a, b)

        x = torch.randn(100, 100).cuda()
        y = torch.randn(100, 100).cuda()
        a = torch.randn(1000, 100).cuda()
        b = torch.randn(1000, 100).cuda()
        f_compiled = torch.compile(f, fullgraph=True)

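        # The two additions have different shapes, so max-autotune produces two
        # kernels: expect two remote-cache misses and puts on the first compile,
        # then two hits after the reset.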
        with PatchCaches(), patch_fbcode(fbcode):
            f_compiled(x, y, a, b)

            self.assertEqual(global_stats.autotune.num_get_hit, 0)
            self.assertEqual(global_stats.autotune.num_get_miss, 2)
            self.assertEqual(global_stats.autotune.num_put, 2)

            self.reset()
            f_compiled(x, y, a, b)

        global_stats.report()
        self.assertEqual(global_stats.autotune.num_get_hit, 2)
        self.assertEqual(global_stats.autotune.num_get_miss, 2)
        self.assertEqual(global_stats.autotune.num_put, 2)


class TestUtils(TestCase):
    @config.patch({"fx_graph_remote_cache": False})
    def test_fresh_inductor_cache(self):
        def fn(x, y):
            return x + y

        a = torch.rand(10)
        b = torch.rand(10)

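        # fresh_inductor_cache() points cache_dir() at a temporary directory, so each
        # block below starts with an empty cache and reports a distinct cache path.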
        with fresh_inductor_cache():
            self.assertEqual(len(PyCodeCache.cache.keys()), 0)
            res1 = torch.compile(fn)(a, b)
            cache_dir1 = cache_dir()

        torch._dynamo.reset()
        with fresh_inductor_cache():
            self.assertEqual(len(PyCodeCache.cache.keys()), 0)
            res2 = torch.compile(fn)(a, b)
            cache_dir2 = cache_dir()

        self.assertEqual(res1, res2)
        self.assertNotEqual(cache_dir1, cache_dir2)


if __name__ == "__main__":
    run_tests()