#
#  Copyright (c) 2023 Apple Inc. All rights reserved.
#  Provided subject to the LICENSE file in the top level directory.
#

import logging
import unittest

from typing import Any, Tuple

import executorch.exir as exir
import torch
from executorch.backends.apple.mps import MPSBackend
from executorch.backends.apple.mps.partition import MPSPartitioner
from executorch.devtools import BundledProgram
from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite
from executorch.devtools.bundled_program.serialize import (
    serialize_from_bundled_program_to_flatbuffer,
)
from executorch.exir import EdgeCompileConfig, ExirExportedProgram, to_edge
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.backend_details import CompileSpec
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.extension.export_util.utils import export_to_edge
from torch.export import export

# Config for capturing the weights; will be moved in the future.

# TODO(T182928844): Delegate dim order op to backend.
_EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig(
    _check_ir_validity=False, _skip_dim_order=True
)
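# In the config above, _check_ir_validity=False relaxes Edge IR verification for
# these tests, and _skip_dim_order=True keeps ops in their default (contiguous)
# representation instead of rewriting them to dim-order variants (see the TODO).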


class ansi_colors:
    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKCYAN = "\033[96m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"
    ENDC = "\033[0m"
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"


class OpSequencesAddConv2d(torch.nn.Module):
48    """
49    Module which include sequences of Memory Format sensitive ops. forward runs
50    [num_sequences] sequences of [ops_per_sequences] ops. Each sequence is
51    followed by an add to separate the sequences
52    """

    def __init__(self, num_sequences, ops_per_sequence):
        super().__init__()
        self.num_ops = num_sequences * ops_per_sequence
        self.num_sequences = num_sequences
        self.op_sequence = [[] for _ in range(num_sequences)]
        for seq in range(num_sequences):
            for _ in range(ops_per_sequence):
                self.op_sequence[seq].append(
                    torch.nn.Conv2d(
                        in_channels=1,
                        out_channels=1,
                        kernel_size=(3, 3),
                        padding=1,
                        bias=False,
                    )
                )

    def forward(self, x):
        for seq in self.op_sequence:
            for op in seq:
                x = op(x)
            x = x + x
        return x + x


def randomize_bn(num_features: int, dimensionality: int = 2) -> torch.nn.Module:
    if dimensionality == 1:
        bn = torch.nn.BatchNorm1d(num_features)
        input_size = (1, num_features, 5)
    elif dimensionality == 2:
        bn = torch.nn.BatchNorm2d(num_features)
        input_size = (1, num_features, 5, 5)
    else:
        raise AssertionError(
            f"Only dimensionality 1 or 2 supported in randomize_bn, got {dimensionality}"
        )

    bn.weight = torch.nn.Parameter(torch.randn(num_features))
    bn.bias = torch.nn.Parameter(torch.randn(num_features))

    # Run a few forward passes so the running statistics move away from their
    # defaults before the module is used in tests.
    for _ in range(5):
        bn(torch.randn(size=input_size))

    return bn

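# Randomizing the affine parameters and warming up the running statistics above
# keeps the BatchNorm from behaving like an identity, which presumably makes
# tests that consume it more meaningful than a default-initialized BatchNorm.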

def dump_bundled_program(sample_inputs, expected_output, executorch_program, func_name):
    method_test_suites = [
        MethodTestSuite(
            method_name="forward",
            test_cases=[
                MethodTestCase(inputs=sample_inputs, expected_outputs=expected_output)
            ],
        )
    ]

    logging.info(f"Expected output: {expected_output}")
    logging.info("  -> Test suites generated successfully")

    bundled_program = BundledProgram(executorch_program, method_test_suites)
    bundled_program_buffer = serialize_from_bundled_program_to_flatbuffer(
        bundled_program
    )

    filename = f"{func_name}.pte"
    logging.info(f"Step 4: Saving bundled program to {filename}")
    with open(filename, "wb") as file:
        file.write(bundled_program_buffer)

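# Note: the buffer written above is a *bundled* program: it packages the
# ExecuTorch program together with the captured inputs and expected outputs, so
# a bundled-program-aware runner can both execute the model and verify its
# results against the recorded reference outputs.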

class TestMPS(unittest.TestCase):
    def assert_outputs_equal(
        self,
        model_output,
        ref_output,
        use_fp16: bool = False,
        atol: float = 1e-03,
        rtol: float = 1e-03,
    ):
133        """
134        Helper testing function that asserts that the model output and the reference output
135        are equal with some tolerance. Due to numerical differences between eager mode and
136        the MPS's backend, we relax the detal such that absolute tolerance is 1e-3. and
137        relative tolerance is 1e-3.
138        """
        # Compare the results from the executor and eager mode directly
        if isinstance(ref_output, (tuple, list)):
            # Multiple outputs: the executor always returns a tuple, even if
            # there is only one output
            assert len(ref_output) == len(
                model_output
            ), "Length of outputs does not match!"
            for i in range(len(ref_output)):
                res_output = model_output[i].cpu()
                expected_output = ref_output[i].cpu()
                if use_fp16 and (
                    expected_output.dtype == torch.float16
                    or res_output.dtype == torch.float16
                ):
                    # cast back from fp16 to fp32 (ExecuTorch results are in FP32 by default)
                    expected_output = expected_output.to(torch.float32)
                    res_output = res_output.to(torch.float32)
                if not torch.allclose(
                    res_output, expected_output, atol=atol, rtol=rtol
                ):
                    mean_err = (
                        (res_output - expected_output).abs() / expected_output
                    ).mean()
                    logging.debug(f"mean err = {mean_err}")
                    self.assertLess(mean_err, 0.05)
        else:
            # With a single output, eager mode returns a tensor while the
            # executor returns a tuple of size 1
            expected_output = ref_output.cpu()
            res_output = model_output[0].cpu()
            if use_fp16 and (
                expected_output.dtype == torch.float16
                or res_output.dtype == torch.float16
            ):
                # cast back from fp16 to fp32 (ExecuTorch results are in FP32 by default)
                expected_output = expected_output.to(torch.float32)
                res_output = res_output.to(torch.float32)
            if not torch.allclose(
                res_output, expected_output, atol=atol, rtol=rtol
            ):
                mean_err = (
                    (res_output - expected_output).abs() / expected_output
                ).mean()
                logging.debug(f"mean err = {mean_err}")
                self.assertLess(mean_err, 0.05)

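    # The helper below follows the usual ExecuTorch lowering flow:
    # torch.export.export_for_training -> export_to_edge -> to_backend (either
    # via the MPSPartitioner or by lowering the whole graph to MPSBackend) ->
    # to_executorch, and then runs the result through the pybind runtime when
    # it is available.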
    def lower_module_and_test_output(
        self,
        module: Any,
        sample_inputs: Tuple[torch.Tensor],
        func_name: str,
        use_partitioner: bool = True,
        use_fp16: bool = False,
        bundled_program=True,
        dynamic_shapes=None,
        atol: float = 1e-03,
        rtol: float = 1e-03,
    ) -> ExirExportedProgram:
        """
        Helper testing function that takes a torch.nn.Module and lowers it to MPS with
        the given sample inputs. It then runs the lowered module and compares its
        outputs with the outputs of the eager module.
        """
        logging.info("Step 1: EXIR capturing of original module")

        model = module.eval()
        # Keep pristine copies of the inputs; running the eager module below may
        # mutate sample_inputs in place.
        original_inputs = []
        for t in sample_inputs:
            original_inputs.append(t.detach().clone())
        original_inputs = tuple(original_inputs)

        expected_output = model(*sample_inputs)

        model = torch.export.export_for_training(
            model, sample_inputs, dynamic_shapes=dynamic_shapes
        ).module()

        edge_program = export_to_edge(
            model,
            sample_inputs,
            dynamic_shapes=dynamic_shapes,
            edge_compile_config=EdgeCompileConfig(
                _check_ir_validity=False,
                _skip_dim_order=True,  # TODO(T182928844): Delegate dim order op to backend.
            ),
        )

        logging.info(
            f"Step 2: Lowering to MPSGraph {'with' if use_partitioner else 'without'} partitioner"
        )
        compile_specs = [CompileSpec("use_fp16", bytes([use_fp16]))]
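        # The "use_fp16" spec above is a single byte (bytes([True]) == b"\x01");
        # as the name suggests, the MPS backend uses it to decide whether the
        # lowered graph should run in fp16 rather than fp32.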

        if use_partitioner:
            logging.info(f"Edge IR graph:\n{edge_program.exported_program()}")
            delegated_program = edge_program.to_backend(
                MPSPartitioner(compile_specs=compile_specs)
            )
            logging.info(
                f"Lowered graph:\n{delegated_program.exported_program().graph}"
            )

            executorch_program = delegated_program.to_executorch(
                config=ExecutorchBackendConfig(extract_delegate_segments=False)
            )
        else:
            delegated_program = to_backend(
                MPSBackend.__name__, edge_program.exported_program(), compile_specs
            )

            executorch_program = to_edge(
                export(
                    delegated_program,
                    sample_inputs,
                ),
                compile_config=exir.EdgeCompileConfig(
                    _check_ir_validity=False,
                    _skip_dim_order=True,  # TODO(T182928844): Delegate dim order op to backend.
                ),
            ).to_executorch(
                config=ExecutorchBackendConfig(extract_delegate_segments=False)
            )

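        # Either path yields an ExecuTorch program: it is bundled with its test
        # case below, and its .buffer (the serialized program bytes) is loaded
        # through the pybind runtime when that is available.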
        if bundled_program:
            dump_bundled_program(
                sample_inputs, expected_output, executorch_program, func_name
            )
        try:
            from executorch.extension.pybindings.portable_lib import (  # @manual
                _load_for_executorch_from_buffer,
            )

            logging.info("Testing delegated program using pybind")

            # Test the model with the executor
            logging.debug("Initializing MPSGraph")
            executorch_module = _load_for_executorch_from_buffer(
                executorch_program.buffer
            )

            model_output = executorch_module.forward(original_inputs)

            logging.info(f"Expected output: {expected_output}")
            logging.info(f"MPS delegate output: {model_output}")
            self.assert_outputs_equal(
                model_output, expected_output, use_fp16=use_fp16, atol=atol, rtol=rtol
            )
            logging.info("Delegated program matches PyTorch Eager mode result!")

            return delegated_program
        except ImportError:
            logging.info(
                "ExecuTorch MPS delegate was built without pybind support. Exiting..."
            )

    def lower_and_test_with_partitioner(
        self,
        graph_module,
        example_inputs,
        func_name: str,
        use_fp16: bool = False,
        dynamic_shapes=None,
        atol: float = 1e-03,
        rtol: float = 1e-03,
    ):
        logging.info(func_name)
        self.lower_module_and_test_output(
            graph_module,
            example_inputs,
            use_partitioner=True,
            func_name=func_name,
            use_fp16=use_fp16,
            dynamic_shapes=dynamic_shapes,
            atol=atol,
            rtol=rtol,
        )

    def lower_and_test_without_partitioner(
        self,
        graph_module,
        example_inputs,
        func_name: str,
        use_fp16: bool = False,
        dynamic_shapes=None,
        atol: float = 1e-03,
        rtol: float = 1e-03,
    ):
        logging.info(func_name)
        self.lower_module_and_test_output(
            graph_module,
            example_inputs,
            use_partitioner=False,
            func_name=func_name,
            use_fp16=use_fp16,
            dynamic_shapes=dynamic_shapes,
            atol=atol,
            rtol=rtol,
        )