xref: /aosp_15_r20/external/executorch/runtime/test/test_runtime.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import tempfile
8import unittest
9from pathlib import Path
10
11import torch
12
13from executorch.extension.pybindings.test.make_test import (
14    create_program,
15    ModuleAdd,
16    ModuleMulti,
17)
18from executorch.runtime import Runtime, Verification
19
20
class RuntimeTest(unittest.TestCase):
    """Tests for the ``executorch.runtime`` Python API.

    Covers the ``Runtime`` singleton, program loading from several source
    kinds, method execution, and the operator registry.
    """

    def test_smoke(self):
        """Load a program from an in-memory buffer and execute ``forward``."""
        ep, inputs = create_program(ModuleAdd())
        runtime = Runtime.get()
        # Demonstrate that get() returns a singleton.
        runtime2 = Runtime.get()
        self.assertIs(runtime, runtime2)
        program = runtime.load_program(ep.buffer, verification=Verification.Minimal)
        method = program.load_method("forward")
        outputs = method.execute(inputs)
        self.assertTrue(torch.allclose(outputs[0], inputs[0] + inputs[1]))

    def test_module_with_multiple_method_names(self):
        """Check that every exported method is listed and individually runnable."""
        ep, inputs = create_program(ModuleMulti())
        runtime = Runtime.get()

        program = runtime.load_program(ep.buffer, verification=Verification.Minimal)
        self.assertEqual(program.method_names, {"forward", "forward2"})
        method = program.load_method("forward")
        outputs = method.execute(inputs)
        self.assertTrue(torch.allclose(outputs[0], inputs[0] + inputs[1]))

        method = program.load_method("forward2")
        outputs = method.execute(inputs)
        self.assertTrue(torch.allclose(outputs[0], inputs[0] + inputs[1] + 1))

    def test_print_operator_names(self):
        """Check that the operator registry is non-empty and contains ``aten::add.out``."""
        # NOTE(review): the exported program is not used below; the call is
        # kept in case exporting has registration side effects -- confirm
        # and drop it if not.
        create_program(ModuleAdd())
        runtime = Runtime.get()

        operator_names = runtime.operator_registry.operator_names
        self.assertGreater(len(operator_names), 0)

        self.assertIn("aten::add.out", operator_names)

    def test_load_program_with_path(self):
        """Load the same program from a ``str`` path, a ``pathlib.Path`` and raw bytes."""
        ep, inputs = create_program(ModuleAdd())
        runtime = Runtime.get()

        def _check_add(program):
            # Run ``forward`` and compare against eager addition of the inputs.
            method = program.load_method("forward")
            outputs = method.execute(inputs)
            self.assertTrue(torch.allclose(outputs[0], inputs[0] + inputs[1]))

        # Write into a temporary *directory* rather than using an open
        # NamedTemporaryFile: reopening a NamedTemporaryFile by name while it
        # is still open is not supported on every platform (e.g. Windows).
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = Path(tmp_dir) / "program.pte"
            path.write_bytes(ep.buffer)
            # str filename
            program = runtime.load_program(str(path))
            _check_add(program)
            # pathlib.Path
            program = runtime.load_program(path)
            _check_add(program)
            # raw bytes read back from the file
            program = runtime.load_program(path.read_bytes())
            _check_add(program)
79