# Owner(s): ["module: inductor"]
import os
import unittest

import torch
import torch._inductor.config as inductor_config
from torch._inductor.autoheuristic.autoheuristic import AutoHeuristic, LocalFeedback
from torch._inductor.autoheuristic.autoheuristic_utils import AHContext
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import get_gpu_shared_memory
from torch.testing._internal.inductor_utils import HAS_CUDA, IS_A100, IS_H100


class AutoHeuristicTest(TestCase):
    def count_lines_in_file(self, file_path):
        with open(file_path) as file:
            line_count = sum(1 for line in file)
        return line_count

    def run_mm(self):
        def f(a, b):
            return torch.mm(a, b)

        cf = torch.compile(f)
        a = torch.randn(2047, 2048, device="cuda", dtype=torch.float16)
        b = torch.randn(2048, 2048, device="cuda", dtype=torch.float16)
        cf(a, b)

    def get_path_to_autoheuristic_log(self, name):
        device_name = AutoHeuristic.get_device_identifier()
        path = cache_dir() + "/autoheuristic/" + device_name + "/" + name + ".txt"
        return path

    def test_autoheuristic_pad_mm_default(self):
        # This test ensures that no data is collected for pad_mm when the
        # autoheuristic config is set to its default value.
        self.run_mm()
        self.assertFalse(os.path.exists(self.get_path_to_autoheuristic_log("pad_mm")))

    @inductor_config.patch(autoheuristic_collect="foo")
    def test_autoheuristic_pad_mm_off(self):
        # This test ensures that no data is collected for pad_mm when
        # autoheuristic_collect does not contain "pad_mm".
        self.run_mm()
        self.assertFalse(os.path.exists(self.get_path_to_autoheuristic_log("pad_mm")))

    def assert_autoheuristic_collected_data(self):
        self.run_mm()
        path = self.get_path_to_autoheuristic_log("pad_mm")
        self.assertTrue(os.path.exists(path))
        num_lines = self.count_lines_in_file(path)
        # 1 line for metadata, 1 line for header, 1 line per choice (orig, padded)
        self.assertEqual(num_lines, 4)

    @inductor_config.patch(autoheuristic_collect="pad_mm")
    def test_autoheuristic_pad_mm_collect_data(self):
        # This test ensures that data is collected for pad_mm when
        # autoheuristic_collect="pad_mm".
        self.assert_autoheuristic_collected_data()

    @inductor_config.patch(autoheuristic_collect="foo,pad_mm")
    def test_autoheuristic_pad_mm_collect_data2(self):
        # This test ensures that data is collected for pad_mm when
        # autoheuristic_collect contains "pad_mm".
        self.assert_autoheuristic_collected_data()

    @inductor_config.patch(autoheuristic_collect="test")
    def test_autoheuristic(self):
        # Test the basic functionality of AutoHeuristic.
        def fallback():
            return "fallback"

        choices = ["a", "b", "c"]

        def feedback_fn(choice):
            if choice == "a":
                return 1
            elif choice == "b":
                return 2
            elif choice == "c":
                return 3
            else:
                raise RuntimeError("unexpected choice")

        feedback = LocalFeedback(feedback_fn)
        context = AHContext()
        context.add_feature("fa", 5)
        name = "test"
        autoheuristic = AutoHeuristic(fallback, choices, feedback, context, name)

        # When autoheuristic is configured to only collect data, we always
        # return the fallback.
        self.assertEqual(autoheuristic.get_choice(), "fallback")
        self.assertEqual(autoheuristic.get_collected_feedback("a"), 1)
        self.assertEqual(autoheuristic.get_collected_feedback("b"), 2)
        self.assertEqual(autoheuristic.get_collected_feedback("c"), 3)

        path = self.get_path_to_autoheuristic_log(name)
        self.assertTrue(os.path.exists(path))
        num_lines = self.count_lines_in_file(path)
        # 1 line for metadata, 1 line for header, 1 line per choice (a, b, c)
        self.assertEqual(num_lines, 5)

        shared_memory = get_gpu_shared_memory()
        (fst, snd) = torch.cuda.get_device_capability()
        with open(path) as file:
            lines = file.readlines()
            self.assertTrue('"numerical_features": ["fa"]' in lines[0])
            self.assertTrue('"categorical_features": []' in lines[0])
            self.assertTrue(f'"shared_memory": {shared_memory}' in lines[0])
            self.assertTrue(f'"device_capa": [{fst}, {snd}]' in lines[0])
            self.assertTrue('"name": "test"' in lines[0])
            self.assertEqual("fa,choice,feedback", lines[1].rstrip())
            self.assertEqual("5,a,1", lines[2].rstrip())
            self.assertEqual("5,b,2", lines[3].rstrip())
            self.assertEqual("5,c,3", lines[4].rstrip())

    @unittest.skipIf(not IS_A100, "heuristic only run on A100")
    @inductor_config.patch(autoheuristic_use="pad_mm")
    def test_autoheuristic_a100(self):
        # Make sure the heuristic does not break anything.
        # TODO (AlnisM): Find a way to check whether the heuristic is used.
        self.run_mm()

    @unittest.skipIf(not IS_H100, "heuristic only run on H100")
    @inductor_config.patch(autoheuristic_use="pad_mm")
    def test_autoheuristic_h100(self):
        # Make sure the heuristic does not break anything.
        # TODO (AlnisM): Find a way to check whether the heuristic is used.
        self.run_mm()

    def run_mixed_mm(self):
        def fn(a, b):
            return torch.mm(a, b.to(a.dtype))

        a = torch.randn(8, 1024, device="cuda", dtype=torch.float16)
        b = torch.randint(-128, 127, (1024, 1024), dtype=torch.int8, device="cuda").t()
        torch.compile(fn, mode="max-autotune-no-cudagraphs")(a, b)

    # We have to set autoheuristic_use="" because, if autoheuristic_use="mixed_mm",
    # autoheuristic creates a precompile key and puts it into the registry; a
    # choice made by the heuristic might then be added to the list of choices,
    # and if select_algorithm now creates a new precompile key, it will be
    # different from the precompile key created by autoheuristic.
    @inductor_config.patch(
        autoheuristic_collect="mixed_mm",
        autoheuristic_use="",
        fx_graph_cache=False,
        fx_graph_remote_cache=False,
    )
    def test_global_feedback(self):
        self.run_mixed_mm()
        path = self.get_path_to_autoheuristic_log("mixed_mm")
        self.assertTrue(os.path.exists(path))
        num_lines = self.count_lines_in_file(path)
        # 1 line for metadata, 1 line for header,
        # 1 line for the fallback + at least 1 config
        self.assertTrue(num_lines > 4)

    @inductor_config.patch(autoheuristic_use="mixed_mm")
    @unittest.skipIf(not IS_A100, "heuristic only run on A100")
    def test_mixed_mm_a100(self):
        self.run_mixed_mm()
        # TODO (AlnisM): Find a way to check whether the heuristic is used.


if __name__ == "__main__":
    if HAS_CUDA:
        run_tests()
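

# For reference, a sketch of the collected log layout, inferred purely from
# the assertions in test_autoheuristic above (not an authoritative spec of
# the on-disk format; the metadata line contains further fields that the
# test does not check, elided here as "..."):
#
#   {"numerical_features": ["fa"], "categorical_features": [], "name": "test", ...}
#   fa,choice,feedback
#   5,a,1
#   5,b,2
#   5,c,3
#
# The first line is JSON metadata describing the context, the second line is
# a CSV header, and each remaining line records one (feature value, choice,
# feedback) row.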