#
#  Copyright (c) 2024 Apple Inc. All rights reserved.
#  Provided subject to the LICENSE file in the top level directory.
#
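
"""Benchmark utilities for the ExecuTorch MPS example
(examples/apple/mps/scripts/bench_utils.py).

Helpers to load a lowered ExecuTorch program through the portable pybindings,
compare its outputs against the eager PyTorch model, and time forward passes
on the MPS backend.
"""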

import logging
import time
from typing import Tuple

import torch
from executorch.backends.apple.mps.test.test_mps_utils import TestMPS
from torch.export.exported_program import ExportedProgram


def bench_forward(func, *args):
    """Time 100 calls of `func(*args)` after a 10-call warmup; return elapsed seconds."""
    # warmup
    for _ in range(10):
        func(*args)

    start = time.time()
    for _ in range(100):
        func(*args)
    end = time.time()
    return end - start


def executorch_forward_pass(model, inputs):
    """Run 10 forward passes of the lowered ExecuTorch module."""
    for _ in range(10):
        model.forward(inputs)


def synchronize():
    """Block until all queued MPS work has finished."""
    torch.mps.synchronize()


def pytorch_forward_pass(model, inputs):
    """Run 10 eager PyTorch forward passes on MPS, then synchronize."""
    for _ in range(10):
        model(*inputs)
    synchronize()


def get_mps_inputs(inputs):
    """Copy each input tensor to the MPS device and return them as a tuple."""
    return tuple(tensor.to("mps") for tensor in inputs)


def get_executorch_model(executorch_program: ExportedProgram):
    """Load the program from `executorch_program.buffer`, or return None if the pybindings are missing."""
    try:
        from executorch.extension.pybindings.portable_lib import (  # @manual
            _load_for_executorch_from_buffer,
        )

        return _load_for_executorch_from_buffer(executorch_program.buffer)
    except ImportError:
        logging.info(
            "ExecuTorch MPS delegate was built without pybind support "
            "(cannot run the forward pass from Python)"
        )
        return None


def bench_torch(executorch_program: ExportedProgram, model, inputs, model_name):
    """Benchmark eager PyTorch on MPS against the ExecuTorch MPS delegate and log the timings."""
    model = model.to("mps")
    inputs_mps = get_mps_inputs(inputs)

    executorch_model = get_executorch_model(executorch_program)
    if executorch_model is not None:
        t_pytorch = bench_forward(pytorch_forward_pass, model, inputs_mps)
        t_executorch = bench_forward(executorch_forward_pass, executorch_model, inputs)

        logging.info(f"Model name: {model_name}")
        logging.info(f"PyTorch MPS forward pass: {t_pytorch} seconds")
        logging.info(f"ExecuTorch MPS forward pass: {t_executorch} seconds")
        logging.info(
            f"ExecuTorch speedup: {((t_pytorch - t_executorch) / t_pytorch) * 100}%"
        )


def compare_outputs(
    executorch_program: ExportedProgram,
    model: torch.nn.Module,
    inputs: Tuple[torch.Tensor],
    model_name: str,
    use_fp16: bool,
):
    """Check that the ExecuTorch MPS delegate and eager PyTorch produce matching outputs."""
    test_module = TestMPS()
    inputs_copy = []
    if use_fp16:
        model = model.to(torch.float16)
    # Run PyTorch on detached copies of the inputs so the originals
    # (consumed by the ExecuTorch module below) stay untouched.
    for t in inputs:
        tensor = t.detach().clone()
        if use_fp16 and tensor.dtype == torch.float32:
            tensor = tensor.to(torch.float16)
        inputs_copy.append(tensor)
    inputs_copy = tuple(inputs_copy)

    pytorch_results = model(*inputs_copy)

    executorch_model = get_executorch_model(executorch_program)
    if executorch_model is not None:
        executorch_results = executorch_model.forward(inputs)
        test_module.assert_outputs_equal(executorch_results, pytorch_results, use_fp16)
        logging.info(
            f"Outputs of the ExecuTorch MPS forward pass and the PyTorch forward pass for {model_name} match!"
        )
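
# Example wiring (a minimal sketch; `MyModel`, the sample input shape, and the
# name "my_model" are hypothetical, and `executorch_program` stands for the
# lowered-to-MPS program object produced by the caller, exposing a `.buffer`):
#
#   model = MyModel().eval()
#   inputs = (torch.randn(1, 3, 224, 224),)
#   compare_outputs(executorch_program, model, inputs, "my_model", use_fp16=False)
#   bench_torch(executorch_program, model, inputs, "my_model")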