#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
This script is run by CI after building the executorch wheel. Before running
this, the job will install the matching torch package as well as the newly-built
executorch package and its dependencies.
"""

# Import this first. If it can't find the torch.so libraries, the dynamic load
# will fail and the process will exit.
from executorch.extension.pybindings import portable_lib  # usort: skip

# Import custom ops. This requires portable_lib to be loaded first.
from executorch.extension.llm.custom_ops import (  # noqa: F401, F403
    sdpa_with_kv_cache,
)  # usort: skip

# Import quantized ops. This requires portable_lib to be loaded first.
from executorch.kernels import quantized  # usort: skip # noqa: F401, F403

# Import this after importing the ExecuTorch pybindings. If the pybindings
# link against a different torch.so than the one this uses, there will be a set
# of symbol conflicts; the process will either exit now, or there will be
# issues later in the smoke test.
import torch  # usort: skip

# Import everything else later to help isolate the critical imports above.
import os
import tempfile
from typing import Tuple

from executorch.exir import to_edge
from torch.export import export


class LinearModel(torch.nn.Module):
    """Runs Linear on its input, which should have shape [4]."""

    def __init__(self) -> None:
        super().__init__()
        # Maps 4 input features to 2 output features.
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Expects a single tensor of shape [4]; returns a tensor of shape [2]."""
        return self.linear(x)


def linear_model_inputs() -> Tuple[torch.Tensor]:
    """Returns some example inputs compatible with LinearModel."""
    # The model takes a single tensor of shape [4] as an input.
    return (torch.ones(4),)


def export_linear_model() -> bytes:
    """Exports LinearModel and returns the .pte data."""

    # This helps the exporter understand the shapes of tensors used in the model.
    # Since our model only takes one input, this is a one-tuple.
    example_inputs = linear_model_inputs()

    # Export the pytorch model and process for ExecuTorch.
    print("Exporting program...")
    exported_program = export(LinearModel(), example_inputs)
    print("Lowering to edge...")
    edge_program = to_edge(exported_program)
    print("Creating ExecuTorch program...")
    et_program = edge_program.to_executorch()

    return et_program.buffer


def main() -> None:
    """Tests the export and execution of a simple model.

    The test checks that:
      * the pybindings report a non-empty operator list;
      * the custom (``llama::sdpa_with_kv_cache.out``) and quantized
        (``quantized_decomposed::add.out``) ops are registered;
      * LinearModel can be exported, saved to a .pte file, loaded back, and
        run, producing a single output of shape [2].

    Failures surface as AssertionError.
    """

    # If the pybindings loaded correctly, we should be able to ask for the set
    # of operators.
    ops = portable_lib._get_operator_names()
    assert len(ops) > 0, "Empty operator list"
    print(f"Found {len(ops)} operators; first element '{ops[0]}'")

    # Make sure custom ops are registered.
    assert (
        "llama::sdpa_with_kv_cache.out" in ops
    ), f"llama::sdpa_with_kv_cache.out not registered, Got ops: {ops}"

    # Make sure quantized ops are registered.
    assert (
        "quantized_decomposed::add.out" in ops
    ), f"quantized_decomposed::add.out not registered, Got ops: {ops}"

    # Export LinearModel to .pte data.
    pte_data: bytes = export_linear_model()

    # Try saving to and loading from a file.
    with tempfile.TemporaryDirectory() as tempdir:
        pte_file = os.path.join(tempdir, "linear.pte")

        # Save the .pte data to a file.
        with open(pte_file, "wb") as file:
            file.write(pte_data)
        print(f"ExecuTorch program saved to {pte_file} ({len(pte_data)} bytes).")

        # Load the model from disk.
        m = portable_lib._load_for_executorch(pte_file)

        # Run the model.
        outputs = m.forward(linear_model_inputs())

        # Should see a single output with shape [2].
        assert len(outputs) == 1, f"Unexpected output length {len(outputs)}: {outputs}"
        assert outputs[0].shape == (2,), f"Unexpected output size {outputs[0].shape}"

    print("PASS")


if __name__ == "__main__":
    main()