xref: /aosp_15_r20/external/executorch/runtime/__init__.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7"""
8Example usage:
9
10.. code-block:: python
11
12    from pathlib import Path
13
14    import torch
15    from executorch.runtime import Verification, Runtime, Program, Method
16
17    et_runtime: Runtime = Runtime.get()
18    program: Program = et_runtime.load_program(
19        Path("/tmp/program.pte"),
20        verification=Verification.Minimal,
21    )
22    print("Program methods:", program.method_names)
23    forward: Method = program.load_method("forward")
24
25    inputs = (torch.ones(2, 2), torch.ones(2, 2))
26    outputs = forward.execute(inputs)
27    print(f"Ran forward({inputs})")
28    print(f"  outputs: {outputs}")
29
30Example output:
31
32.. code-block:: text
33
34    Program methods: {'forward', 'forward2'}
35    Ran forward((tensor([[1., 1.],
36            [1., 1.]]), tensor([[1., 1.],
37            [1., 1.]])))
38      outputs: [tensor([[1., 1.],
39            [1., 1.]])]
40"""
41
42import functools
43from pathlib import Path
44from types import ModuleType
45from typing import Any, BinaryIO, Dict, Optional, Sequence, Set, Union
46
47try:
48    from executorch.extension.pybindings.portable_lib import (
49        ExecuTorchModule,
50        MethodMeta,
51        Verification,
52    )
53except ModuleNotFoundError as e:
54    raise ModuleNotFoundError(
55        "Prebuilt <site-packages>/extension/pybindings/_portable_lib.so "
56        "is not found. Please reinstall ExecuTorch from pip."
57    ) from e
58
59
class Method:
    """A single ExecuTorch method that was loaded from a Program.

    Call :meth:`execute` to run the method on a set of inputs, or read
    :attr:`metadata` to inspect it.
    """

    def __init__(self, method_name: str, module: ExecuTorchModule) -> None:
        # TODO: This class should be pybind to the C++ counterpart instead of hosting ExecuTorchModule.
        # The wrapped module owns the actual method; this object only remembers
        # which of the module's methods it refers to.
        self._module = module
        self._method_name = method_name

    def execute(self, inputs: Sequence[Any]) -> Sequence[Any]:
        """Runs this method on ``inputs``.

        Args:
            inputs: The inputs to pass to the method.

        Returns:
            Whatever the underlying module's ``run_method`` returns for this
            method, i.e. the method's outputs.
        """
        target = self._method_name
        return self._module.run_method(target, inputs)

    @property
    def metadata(self) -> MethodMeta:
        """The metadata of this method.

        Returns:
            The ``MethodMeta`` the underlying module reports for this method.
        """
        module = self._module
        return module.method_meta(self._method_name)
89
90
class Program:
    """An ExecuTorch program, loaded from binary PTE data.

    Use it to list and load the methods/models that the program defines.
    """

    def __init__(self, module: ExecuTorchModule, data: Optional[bytes]) -> None:
        # Keep a reference to the raw program bytes (None when the module was
        # loaded straight from a file path) so they are not freed while the
        # module may still be using them.
        self._data = data
        self._module = module
        # ExecuTorchModule already pre-loads all Methods when created, so
        # wrapping them here doesn't do any extra work. TODO: Don't load a
        # given Method until load_method() is called. Create a separate Method
        # instance each time, to allow multiple independent instances of the
        # same model.
        self._methods: Dict[str, Method] = {
            name: Method(name, self._module) for name in self._module.method_names()
        }

    @property
    def method_names(self) -> Set[str]:
        """The names of every method in the `Program`, as a set of strings."""
        return set(self._methods)

    def load_method(self, name: str) -> Optional[Method]:
        """Looks up a method of the program by name.

        Args:
            name: The name of the method to load.

        Returns:
            The loaded method, or None if the program has no method with that
            name.
        """
        return self._methods.get(name)
126
127
class OperatorRegistry:
    """Registry of the operators that are available to the runtime."""

    def __init__(self, legacy_module: ModuleType) -> None:
        # TODO: Expose the kernel callables to Python.
        self._legacy_module = legacy_module

    @property
    def operator_names(self) -> Set[str]:
        """
        The names of every registered operator, as a set of strings.

        Duplicate names reported by the underlying module collapse into a
        single entry.
        """
        registered = self._legacy_module._get_operator_names()
        return set(registered)
141
142
class Runtime:
    """An instance of the ExecuTorch runtime environment.

    This can be used to concurrently load and execute any number of ExecuTorch
    programs and methods.
    """

    @staticmethod
    @functools.lru_cache(maxsize=1)
    def get() -> "Runtime":
        """Gets the Runtime singleton.

        The first call imports ``executorch.extension.pybindings.portable_lib``
        and wraps it in a Runtime; ``lru_cache(maxsize=1)`` makes every later
        call return that same instance.
        """
        import executorch.extension.pybindings.portable_lib as legacy_module

        return Runtime(legacy_module=legacy_module)

    def __init__(self, *, legacy_module: ModuleType) -> None:
        """Wraps ``legacy_module``; callers normally use :meth:`Runtime.get`.

        Args:
            legacy_module: The pybindings module that provides the program
                loaders and the operator registry.
        """
        # Public attributes.
        self.operator_registry = OperatorRegistry(legacy_module)
        # Private attributes.
        self._legacy_module = legacy_module

    def load_program(
        self,
        data: Union[bytes, bytearray, BinaryIO, Path, str],
        *,
        verification: Verification = Verification.InternalConsistency,
    ) -> Program:
        """Loads an ExecuTorch program from a PTE binary.

        Args:
            data: The program to load. A ``str``/``Path`` is treated as the
                path of a .pte file and loaded from disk. ``bytes``/``bytearray``
                are loaded from memory. A binary file-like object (anything
                with a ``read()`` returning bytes, e.g. ``open(p, "rb")`` or
                ``io.BytesIO``) is read in full and loaded from memory.
            verification: level of program verification to perform.

        Returns:
            The loaded program.

        Raises:
            TypeError: If ``data`` is none of the supported kinds, or is a
                file-like object whose ``read()`` does not return bytes
                (e.g. a file opened in text mode).
        """
        if isinstance(data, (Path, str)):
            m = self._legacy_module._load_for_executorch(
                str(data),
                enable_etdump=False,
                debug_buffer_size=0,
                program_verification=verification,
            )
            # Loaded from a file, so there is no in-memory buffer to keep alive.
            return Program(m, data=None)

        data_bytes = self._to_program_bytes(data)
        m = self._legacy_module._load_for_executorch_from_buffer(
            data_bytes,
            enable_etdump=False,
            debug_buffer_size=0,
            program_verification=verification,
        )
        # Program holds on to data_bytes so the buffer outlives the module.
        return Program(m, data=data_bytes)

    @staticmethod
    def _to_program_bytes(data: Any) -> bytes:
        """Converts in-memory or file-like program data to ``bytes``.

        Raises:
            TypeError: If ``data`` is not bytes, bytearray, or a binary
                file-like object.
        """
        if isinstance(data, bytes):
            return data
        if isinstance(data, bytearray):
            return bytes(data)
        # typing.BinaryIO is only an annotation helper: real binary streams
        # (open(..., "rb"), io.BytesIO, ...) are not instances of it, so an
        # isinstance(data, BinaryIO) check rejects every actual file object.
        # Duck-type on read() instead.
        read = getattr(data, "read", None)
        if callable(read):
            contents = read()
            if isinstance(contents, bytes):
                return contents
            if isinstance(contents, (bytearray, memoryview)):
                return bytes(contents)
            raise TypeError(
                f"Expected the file-like object to be opened in binary mode, but read() returned {type(contents).__name__}."
            )
        raise TypeError(
            f"Expected data to be bytes, bytearray, a path to a .pte file, or a file-like object, but got {type(data).__name__}."
        )
205