#
# Copyright (c) 2024 Apple Inc. All rights reserved.
# Provided subject to the LICENSE file in the top level directory.
#

"""Helpers to benchmark and validate ExecuTorch MPS-delegated programs.

Each helper compares an ExecuTorch program, run through the optional pybind
runtime, against the equivalent eager PyTorch module running on MPS.
"""

import logging
import time
from typing import Tuple

import torch
from executorch.backends.apple.mps.test.test_mps_utils import TestMPS
from torch.export.exported_program import ExportedProgram


def bench_forward(func, *args, warmup: int = 10, iterations: int = 100) -> float:
    """Return the elapsed seconds for ``iterations`` calls of ``func(*args)``.

    ``func`` is first called ``warmup`` times without timing, so one-off
    first-call costs are kept out of the measurement.

    Args:
        func: Callable to benchmark; it receives ``*args`` only.
        *args: Positional arguments forwarded to ``func`` on every call.
        warmup: Number of untimed calls made before measuring.
        iterations: Number of timed calls.

    Returns:
        Total elapsed seconds for the timed calls (not divided by
        ``iterations``).
    """
    for _ in range(warmup):
        func(*args)

    # perf_counter is monotonic and high-resolution; time.time() follows the
    # adjustable system clock and is unsuitable for measuring intervals.
    start = time.perf_counter()
    for _ in range(iterations):
        func(*args)
    return time.perf_counter() - start


def executorch_forward_pass(model, inputs):
    """Call ``model.forward(inputs)`` 10 times on a loaded ExecuTorch module.

    ``inputs`` is passed as a single sequence argument (not unpacked), unlike
    the eager path in ``pytorch_forward_pass``.
    """
    for _ in range(10):
        model.forward(inputs)


def synchronize():
    """Block until all queued MPS work has completed."""
    torch.mps.synchronize()


def pytorch_forward_pass(model, inputs):
    """Call the eager ``model(*inputs)`` 10 times, then synchronize MPS.

    The trailing synchronize makes a caller's timer include the device work
    queued by these calls, not just the time taken to enqueue it.
    """
    for _ in range(10):
        model(*inputs)
    synchronize()


def get_mps_inputs(inputs):
    """Return a tuple holding a copy of each tensor in ``inputs`` on the MPS device."""
    return tuple(tensor.to("mps") for tensor in inputs)


def get_executorch_model(executorch_program: ExportedProgram):
    """Load ``executorch_program`` into the ExecuTorch pybind runtime.

    NOTE(review): this reads ``executorch_program.buffer``. The
    ``ExportedProgram`` annotation may not match the object callers actually
    pass (presumably an ExecuTorch program manager that exposes ``.buffer``);
    confirm.

    Returns:
        The loaded module, or ``None`` if the pybind runtime cannot be
        imported (i.e. ExecuTorch was built without pybind support).
    """
    # Only the import belongs in the try: an ImportError raised while loading
    # the program must not be misreported as "built without pybind support".
    try:
        from executorch.extension.pybindings.portable_lib import (  # @manual
            _load_for_executorch_from_buffer,
        )
    except ImportError:
        logging.info(
            "ExecuTorch MPS delegate was built without pybind support (not possible to run forward pass within python)"
        )
        return None

    return _load_for_executorch_from_buffer(executorch_program.buffer)


def bench_torch(executorch_program: ExportedProgram, model, inputs, model_name):
    """Benchmark eager PyTorch on MPS against the ExecuTorch program and log it.

    If the ExecuTorch pybind runtime is unavailable, nothing is benchmarked.
    ``model`` and ``inputs`` are still moved or copied to MPS first.

    Args:
        executorch_program: Program whose ``.buffer`` is loaded by
            ``get_executorch_model``.
        model: Eager module. ``model.to("mps")`` is applied to it, which for an
            ``nn.Module`` converts the caller's instance in place.
        inputs: Input tensors. They are passed unchanged to ExecuTorch; MPS
            copies of them are fed to the eager model.
        model_name: Label used only in the log output.
    """
    model = model.to("mps")
    inputs_mps = get_mps_inputs(inputs)

    executorch_model = get_executorch_model(executorch_program)
    if executorch_model is None:
        return

    t_pytorch = bench_forward(pytorch_forward_pass, model, inputs_mps)
    t_executorch = bench_forward(executorch_forward_pass, executorch_model, inputs)

    logging.info(f"Model name: {model_name}")
    logging.info(f"Pytorch MPS forward pass: {t_pytorch} seconds")
    logging.info(f"ExecuTorch MPS forward pass: {t_executorch} seconds")
    # Percentage of the eager PyTorch time saved by ExecuTorch
    # (positive means ExecuTorch was faster).
    logging.info(
        f"ExecuTorch speedup: {((t_pytorch - t_executorch) / t_pytorch) * 100}%"
    )


def _clone_input(tensor: torch.Tensor, use_fp16: bool) -> torch.Tensor:
    """Return a detached copy of ``tensor``, cast fp32 -> fp16 when ``use_fp16``."""
    copy = tensor.detach().clone()
    if use_fp16 and copy.dtype == torch.float32:
        copy = copy.to(torch.float16)
    return copy


def compare_outputs(
    executorch_program: ExportedProgram,
    model: torch.nn.Module,
    inputs: Tuple[torch.Tensor, ...],
    model_name: str,
    use_fp16: bool,
) -> None:
    """Check that the ExecuTorch program and the eager model produce matching outputs.

    The eager model runs on detached copies of ``inputs``, which are cast to
    fp16 when ``use_fp16`` is set. The ExecuTorch program receives the
    original ``inputs`` unchanged. The comparison is delegated to
    ``TestMPS.assert_outputs_equal``, which presumably raises on mismatch
    (defined outside this file).

    If the ExecuTorch pybind runtime is unavailable, the eager forward pass
    still runs but no comparison is made.

    Args:
        executorch_program: Program whose ``.buffer`` is loaded by
            ``get_executorch_model``.
        model: Eager module. With ``use_fp16``, ``model.to(torch.float16)`` is
            applied, which converts the caller's instance in place.
        inputs: Input tensors for both runs.
        model_name: Label used only in the log output.
        use_fp16: Run the eager model in fp16 and forward the flag to the
            output comparison.
    """
    test_module = TestMPS()
    if use_fp16:
        model = model.to(torch.float16)
    inputs_copy = tuple(_clone_input(t, use_fp16) for t in inputs)

    pytorch_results = model(*inputs_copy)

    executorch_model = get_executorch_model(executorch_program)
    if executorch_model is None:
        return

    executorch_results = executorch_model.forward(inputs)
    test_module.assert_outputs_equal(executorch_results, pytorch_results, use_fp16)
    logging.info(
        f"Results between ExecuTorch forward pass with MPS backend and PyTorch forward pass for {model_name} are matching!"
    )