xref: /aosp_15_r20/external/pytorch/test/quantization/core/test_docs.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: quantization"]
2
3import re
4import contextlib
5from pathlib import Path
6
7import torch
8
9# import torch.ao.nn.quantized as nnq
10from torch.testing._internal.common_quantization import (
11    QuantizationTestCase,
12    SingleLayerLinearModel,
13)
14from torch.testing._internal.common_quantized import override_quantized_engine
15from torch.testing._internal.common_utils import IS_ARM64
16
17
class TestQuantizationDocs(QuantizationTestCase):
    r"""
    The tests in this section extract code snippets from the quantization docs and
    check that they actually run without errors. Each snippet is executed in its own
    fresh namespace, exactly as if it were a standalone script. Any object that a
    snippet uses without defining or importing it must therefore be provided by the
    test through ``global_inputs``.
    """

    def run(self, result=None):
        # On ARM64 hosts every test in this class runs under the qnnpack quantized
        # engine; elsewhere the currently configured engine is left alone.
        engine_ctx = (
            override_quantized_engine("qnnpack") if IS_ARM64 else contextlib.nullcontext()
        )
        with engine_ctx:
            # Propagate the TestResult, as unittest.TestCase.run does.
            return super().run(result)

    def _get_code(
        self, path_from_pytorch, unique_identifier, offset=2, short_snippet=False
    ):
        r"""
        Read a code snippet out of a docs file, given a unique identifier line.

        The snippet is the indented literal block that follows the line equal to
        ``unique_identifier`` (e.g. ``"PTDQ API Example::"``). The line directly
        after the identifier, normally the blank line that RST requires, is skipped.
        The snippet ends right before the first line that starts at column 0 (its
        first character is neither a space nor a newline), or at end of file.

        Args:
            path_from_pytorch: path of the docs file relative to the pytorch root.
            unique_identifier: exact text of the line that introduces the snippet.
                A trailing newline is appended if it is missing.
            offset: number of leading characters stripped from every snippet line.
                Most snippets use a 2 space indentation. For other indentation
                levels, change this argument.
            short_snippet: set to True to skip the minimum-length check. That check
                exists to make sure we are not accidentally extracting only a blank
                line or something similar.

        Returns:
            The de-indented snippet source as a single string.
        """

        def get_correct_path(path_from_pytorch):
            r"""
            Resolve ``path_from_pytorch`` against the pytorch root.

            The current working directory when CI runs the test varies, so the
            docs are located relative to this test file instead. The root is three
            directories above ``test/quantization/core``.
            """
            core_dir = Path(__file__).parent
            assert core_dir.match("test/quantization/core/"), (
                "test_docs.py is in an unexpected location. If you've been "
                "moving files around, ensure that the test and build files have "
                "been updated to have the correct relative path between "
                "test_docs.py and the docs."
            )
            pytorch_root = core_dir.parent.parent.parent
            return pytorch_root / path_from_pytorch

        path_to_file = get_correct_path(path_from_pytorch)
        # Read with an explicit encoding so the result does not depend on the
        # platform's default locale.
        with open(path_to_file, encoding="utf-8") as file:
            content = file.readlines()

        # readlines() keeps line terminators, so the identifier needs one in order
        # to compare equal to a line of the file.
        if "\n" not in unique_identifier:
            unique_identifier += "\n"

        assert unique_identifier in content, f"could not find {unique_identifier} in {path_to_file}"

        # Index of the line right after the identifier (normally the blank line
        # that separates "Example::" from the literal block).
        line_num_start = content.index(unique_identifier) + 1

        # Matches a line whose first character is neither a newline nor a space,
        # i.e. a non-blank line at column 0, which ends the indented block.
        ends_snippet = re.compile(r"^[^\n ]")
        # Keep the index of the match found *after* the identifier. Looking the
        # matched text up again with content.index() would return an earlier
        # duplicate of the same line (e.g. a repeated heading underline). If no
        # line matches, the snippet runs to the end of the file.
        last_line_num = next(
            (
                idx
                for idx in range(line_num_start, len(content))
                if ends_snippet.match(content[idx])
            ),
            len(content),
        )

        # Skip the separator line after the identifier, then remove the first
        # `offset` characters of each line and gather it all together.
        code = "".join(
            line[offset:] for line in content[line_num_start + 1 : last_line_num]
        )

        # Make sure we are actually getting some code.
        assert last_line_num - line_num_start > 3 or short_snippet, (
            f"The code in {path_to_file} identified by {unique_identifier} seems suspiciously short:"
            f"\n\n###code-start####\n{code}###code-end####"
        )
        return code

    def _test_code(self, code, global_inputs=None):
        r"""
        Execute ``code`` in a fresh namespace.

        Args:
            code: source text to run. ``None`` means there is no snippet and is a
                no-op.
            global_inputs: optional mapping of names to pre-define for the
                snippet, such as objects the docs use but never define. The
                mapping is copied, so the caller's dict is not modified.
        """
        if code is None:
            return
        # A single dict serves as both globals and locals, so functions and
        # classes defined by the snippet can see its own top-level imports and
        # names. Passing None would instead run the snippet against this module's
        # globals plus a snapshot of this method's locals.
        namespace = dict(global_inputs) if global_inputs is not None else {}
        expr = compile(code, "test", "exec")
        exec(expr, namespace)

    def test_quantization_doc_ptdq(self):
        path_from_pytorch = "docs/source/quantization.rst"
        unique_identifier = "PTDQ API Example::"
        code = self._get_code(path_from_pytorch, unique_identifier)
        self._test_code(code)

    def test_quantization_doc_ptsq(self):
        path_from_pytorch = "docs/source/quantization.rst"
        unique_identifier = "PTSQ API Example::"
        code = self._get_code(path_from_pytorch, unique_identifier)
        self._test_code(code)

    def test_quantization_doc_qat(self):
        path_from_pytorch = "docs/source/quantization.rst"
        unique_identifier = "QAT API Example::"

        def _dummy_func(*args, **kwargs):
            return None

        # Stand-ins for names the snippet uses but does not define: a no-op
        # training loop and a small random input tensor.
        input_fp32 = torch.randn(1, 1, 1, 1)
        global_inputs = {"training_loop": _dummy_func, "input_fp32": input_fp32}
        code = self._get_code(path_from_pytorch, unique_identifier)
        self._test_code(code, global_inputs)

    def test_quantization_doc_fx(self):
        path_from_pytorch = "docs/source/quantization.rst"
        unique_identifier = "FXPTQ API Example::"

        # Substitute a small test model, and its example inputs, for the
        # snippet's `UserModel` and `input_fp32` placeholders.
        input_fp32 = SingleLayerLinearModel().get_example_inputs()
        global_inputs = {"UserModel": SingleLayerLinearModel, "input_fp32": input_fp32}

        code = self._get_code(path_from_pytorch, unique_identifier)
        self._test_code(code, global_inputs)

    def test_quantization_doc_custom(self):
        path_from_pytorch = "docs/source/quantization.rst"
        unique_identifier = "Custom API Example::"

        # Provide the `nnq` alias that the snippet refers to.
        global_inputs = {"nnq": torch.ao.nn.quantized}

        code = self._get_code(path_from_pytorch, unique_identifier)
        self._test_code(code, global_inputs)
143