# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest
from pathlib import Path

import torch

from executorch.extension.pybindings.test.make_test import (
    create_program,
    ModuleAdd,
    ModuleMulti,
)
from executorch.runtime import Runtime, Verification


class RuntimeTest(unittest.TestCase):
    """Tests for the ``executorch.runtime`` Python API: ``Runtime``, program
    and method loading, and the operator registry."""

    def test_smoke(self):
        """Load a single-method program and check ``forward`` adds its inputs.

        Also checks that ``Runtime.get()`` returns the same object each call.
        """
        ep, inputs = create_program(ModuleAdd())
        runtime = Runtime.get()
        # Demonstrate that get() returns a singleton.
        runtime2 = Runtime.get()
        self.assertIs(runtime, runtime2)
        program = runtime.load_program(ep.buffer, verification=Verification.Minimal)
        method = program.load_method("forward")
        outputs = method.execute(inputs)
        self.assertTrue(torch.allclose(outputs[0], inputs[0] + inputs[1]))

    def test_module_with_multiple_method_names(self):
        """Check that every method of a multi-method program is listed and runs."""
        ep, inputs = create_program(ModuleMulti())
        runtime = Runtime.get()

        program = runtime.load_program(ep.buffer, verification=Verification.Minimal)
        self.assertEqual(program.method_names, {"forward", "forward2"})

        method = program.load_method("forward")
        outputs = method.execute(inputs)
        self.assertTrue(torch.allclose(outputs[0], inputs[0] + inputs[1]))

        method = program.load_method("forward2")
        outputs = method.execute(inputs)
        self.assertTrue(torch.allclose(outputs[0], inputs[0] + inputs[1] + 1))

    def test_print_operator_names(self):
        """Check the operator registry is non-empty and has ``aten::add.out``."""
        # NOTE(review): the exported program is never loaded in this test; the
        # call is kept in case exporting has side effects the assertions below
        # rely on. TODO confirm whether it is needed and drop it if not.
        create_program(ModuleAdd())
        runtime = Runtime.get()

        operator_names = runtime.operator_registry.operator_names
        self.assertGreater(len(operator_names), 0)

        self.assertIn("aten::add.out", operator_names)

    def test_load_program_with_path(self):
        """Check ``load_program`` accepts a str filename, a ``Path`` and bytes."""
        ep, inputs = create_program(ModuleAdd())
        runtime = Runtime.get()

        def check_add(program):
            # Run ``forward`` on the shared inputs and verify it computes a + b.
            method = program.load_method("forward")
            outputs = method.execute(inputs)
            self.assertTrue(torch.allclose(outputs[0], inputs[0] + inputs[1]))

        # NOTE: reopening a NamedTemporaryFile by name while it is still open
        # is not supported on every platform (notably Windows); this test
        # relies on POSIX behavior.
        with tempfile.NamedTemporaryFile() as tmp:
            tmp.write(ep.buffer)
            tmp.flush()
            path = Path(tmp.name)

            with self.subTest(source="str filename"):
                check_add(runtime.load_program(tmp.name))
            with self.subTest(source="pathlib.Path"):
                check_add(runtime.load_program(path))
            with self.subTest(source="bytes"):
                # Read the serialized program back and load it from memory.
                check_add(runtime.load_program(path.read_bytes()))