1# Owner(s): ["oncall: quantization"] 2 3import re 4import contextlib 5from pathlib import Path 6 7import torch 8 9# import torch.ao.nn.quantized as nnq 10from torch.testing._internal.common_quantization import ( 11 QuantizationTestCase, 12 SingleLayerLinearModel, 13) 14from torch.testing._internal.common_quantized import override_quantized_engine 15from torch.testing._internal.common_utils import IS_ARM64 16 17 18class TestQuantizationDocs(QuantizationTestCase): 19 r""" 20 The tests in this section import code from the quantization docs and check that 21 they actually run without errors. In cases where objects are undefined in the code snippet, 22 they must be provided in the test. The imports seem to behave a bit inconsistently, 23 they can be imported either in the test file or passed as a global input 24 """ 25 26 def run(self, result=None): 27 with override_quantized_engine("qnnpack") if IS_ARM64 else contextlib.nullcontext(): 28 super().run(result) 29 30 def _get_code( 31 self, path_from_pytorch, unique_identifier, offset=2, short_snippet=False 32 ): 33 r""" 34 This function reads in the code from the docs given a unique identifier. 35 Most code snippets have a 2 space indentation, for other indentation levels, 36 change the offset `arg`. the `short_snippet` arg can be set to allow for testing 37 of smaller snippets, the check that this arg controls is used to make sure that 38 we are not accidentally only importing a blank line or something. 39 """ 40 41 def get_correct_path(path_from_pytorch): 42 r""" 43 Current working directory when CI is running test seems to vary, this function 44 looks for docs relative to this test file. 45 """ 46 core_dir = Path(__file__).parent 47 assert core_dir.match("test/quantization/core/"), ( 48 "test_docs.py is in an unexpected location. If you've been " 49 "moving files around, ensure that the test and build files have " 50 "been updated to have the correct relative path between " 51 "test_docs.py and the docs." 52 ) 53 pytorch_root = core_dir.parent.parent.parent 54 return pytorch_root / path_from_pytorch 55 56 path_to_file = get_correct_path(path_from_pytorch) 57 if path_to_file: 58 with open(path_to_file) as file: 59 content = file.readlines() 60 61 # it will register as having a newline at the end in python 62 if "\n" not in unique_identifier: 63 unique_identifier += "\n" 64 65 assert unique_identifier in content, f"could not find {unique_identifier} in {path_to_file}" 66 67 # get index of first line of code 68 line_num_start = content.index(unique_identifier) + 1 69 70 # next find where the code chunk ends. 71 # this regex will match lines that don't start 72 # with a \n or " " with number of spaces=offset 73 r = r = re.compile("^[^\n," + " " * offset + "]") 74 # this will return the line of first line that matches regex 75 line_after_code = next(filter(r.match, content[line_num_start:])) 76 last_line_num = content.index(line_after_code) 77 78 # remove the first `offset` chars of each line and gather it all together 79 code = "".join( 80 [x[offset:] for x in content[line_num_start + 1 : last_line_num]] 81 ) 82 83 # want to make sure we are actually getting some code, 84 assert last_line_num - line_num_start > 3 or short_snippet, ( 85 f"The code in {path_to_file} identified by {unique_identifier} seems suspiciously short:" 86 f"\n\n###code-start####\n{code}###code-end####" 87 ) 88 return code 89 90 return None 91 92 def _test_code(self, code, global_inputs=None): 93 r""" 94 This function runs `code` using any vars in `global_inputs` 95 """ 96 # if couldn't find the 97 if code is not None: 98 expr = compile(code, "test", "exec") 99 exec(expr, global_inputs) 100 101 def test_quantization_doc_ptdq(self): 102 path_from_pytorch = "docs/source/quantization.rst" 103 unique_identifier = "PTDQ API Example::" 104 code = self._get_code(path_from_pytorch, unique_identifier) 105 self._test_code(code) 106 107 def test_quantization_doc_ptsq(self): 108 path_from_pytorch = "docs/source/quantization.rst" 109 unique_identifier = "PTSQ API Example::" 110 code = self._get_code(path_from_pytorch, unique_identifier) 111 self._test_code(code) 112 113 def test_quantization_doc_qat(self): 114 path_from_pytorch = "docs/source/quantization.rst" 115 unique_identifier = "QAT API Example::" 116 117 def _dummy_func(*args, **kwargs): 118 return None 119 120 input_fp32 = torch.randn(1, 1, 1, 1) 121 global_inputs = {"training_loop": _dummy_func, "input_fp32": input_fp32} 122 code = self._get_code(path_from_pytorch, unique_identifier) 123 self._test_code(code, global_inputs) 124 125 def test_quantization_doc_fx(self): 126 path_from_pytorch = "docs/source/quantization.rst" 127 unique_identifier = "FXPTQ API Example::" 128 129 input_fp32 = SingleLayerLinearModel().get_example_inputs() 130 global_inputs = {"UserModel": SingleLayerLinearModel, "input_fp32": input_fp32} 131 132 code = self._get_code(path_from_pytorch, unique_identifier) 133 self._test_code(code, global_inputs) 134 135 def test_quantization_doc_custom(self): 136 path_from_pytorch = "docs/source/quantization.rst" 137 unique_identifier = "Custom API Example::" 138 139 global_inputs = {"nnq": torch.ao.nn.quantized} 140 141 code = self._get_code(path_from_pytorch, unique_identifier) 142 self._test_code(code, global_inputs) 143