xref: /aosp_15_r20/external/pytorch/torch/jit/mobile/__init__.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import os
3
4import torch
5from torch.jit._serialization import validate_map_location
6
7
8def _load_for_lite_interpreter(f, map_location=None):
9    r"""
10    Load a :class:`LiteScriptModule` saved with :func:`torch.jit._save_for_lite_interpreter`.
11
12    Args:
13        f: a file-like object (has to implement read, readline, tell, and seek),
14            or a string containing a file name
15        map_location: a string or torch.device used to dynamically remap
16            storages to an alternative set of devices.
17
18    Returns:
19        A :class:`LiteScriptModule` object.
20
21    Example:
22
23    .. testcode::
24
25        import torch
26        import io
27
28        # Load LiteScriptModule from saved file path
29        torch.jit._load_for_lite_interpreter('lite_script_module.pt')
30
31        # Load LiteScriptModule from io.BytesIO object
32        with open('lite_script_module.pt', 'rb') as f:
33            buffer = io.BytesIO(f.read())
34
35        # Load all tensors to the original device
36        torch.jit.mobile._load_for_lite_interpreter(buffer)
37    """
38    if isinstance(f, (str, os.PathLike)):
39        if not os.path.exists(f):
40            raise ValueError(f"The provided filename {f} does not exist")
41        if os.path.isdir(f):
42            raise ValueError(f"The provided filename {f} is a directory")
43
44    map_location = validate_map_location(map_location)
45
46    if isinstance(f, (str, os.PathLike)):
47        cpp_module = torch._C._load_for_lite_interpreter(os.fspath(f), map_location)
48    else:
49        cpp_module = torch._C._load_for_lite_interpreter_from_buffer(
50            f.read(), map_location
51        )
52
53    return LiteScriptModule(cpp_module)
54
55
class LiteScriptModule:
    r"""Python wrapper around a C++ mobile module produced by the lite interpreter loader.

    Every method delegates to the wrapped object stored in ``self._c``. Positional
    arguments are packed into a single tuple before being passed on.
    """

    def __init__(self, cpp_module):
        # Handle to the underlying C++ module; every method below delegates to it.
        self._c = cpp_module
        super().__init__()

    def __call__(self, *inputs):
        """Run the wrapped module's ``forward`` with ``inputs`` packed as a tuple."""
        return self._c.forward(inputs)

    def find_method(self, method_name):
        """Return the wrapped module's result for looking up ``method_name``."""
        return self._c.find_method(method_name)

    def forward(self, *inputs):
        """Run the wrapped module's ``forward`` with ``inputs`` packed as a tuple."""
        return self._c.forward(inputs)

    def run_method(self, method_name, *inputs):
        """Run ``method_name`` on the wrapped module with ``inputs`` packed as a tuple."""
        return self._c.run_method(method_name, inputs)
72
73
def _export_operator_list(module: LiteScriptModule):
    r"""Return a set of root operator names (with overload name) used by any method of a mobile module.

    Args:
        module: a :class:`LiteScriptModule`; its wrapped C++ module (``module._c``)
            is what gets inspected.

    Returns:
        The value produced by ``torch._C._export_operator_list`` for the wrapped module.
    """
    wrapped = module._c
    return torch._C._export_operator_list(wrapped)
77
78
79def _get_model_bytecode_version(f_input) -> int:
80    r"""Take a file-like object to return an integer.
81
82    Args:
83        f_input: a file-like object (has to implement read, readline, tell, and seek),
84            or a string containing a file name
85
86    Returns:
87        version: An integer. If the integer is -1, the version is invalid. A warning
88            will show in the log.
89
90    Example:
91    .. testcode::
92
93        from torch.jit.mobile import _get_model_bytecode_version
94
95        # Get bytecode version from a saved file path
96        version = _get_model_bytecode_version("path/to/model.ptl")
97
98    """
99    if isinstance(f_input, (str, os.PathLike)):
100        if not os.path.exists(f_input):
101            raise ValueError(f"The provided filename {f_input} does not exist")
102        if os.path.isdir(f_input):
103            raise ValueError(f"The provided filename {f_input} is a directory")
104
105    if isinstance(f_input, (str, os.PathLike)):
106        return torch._C._get_model_bytecode_version(os.fspath(f_input))
107    else:
108        return torch._C._get_model_bytecode_version_from_buffer(f_input.read())
109
110
111def _get_mobile_model_contained_types(f_input) -> int:
112    r"""Take a file-like object and return a set of string, like ("int", "Optional").
113
114    Args:
115        f_input: a file-like object (has to implement read, readline, tell, and seek),
116            or a string containing a file name
117
118    Returns:
119        type_list: A set of string, like ("int", "Optional"). These are types used in bytecode.
120
121    Example:
122
123    .. testcode::
124
125        from torch.jit.mobile import _get_mobile_model_contained_types
126
127        # Get type list from a saved file path
128        type_list = _get_mobile_model_contained_types("path/to/model.ptl")
129
130    """
131    if isinstance(f_input, (str, os.PathLike)):
132        if not os.path.exists(f_input):
133            raise ValueError(f"The provided filename {f_input} does not exist")
134        if os.path.isdir(f_input):
135            raise ValueError(f"The provided filename {f_input} is a directory")
136
137    if isinstance(f_input, (str, os.PathLike)):
138        return torch._C._get_mobile_model_contained_types(os.fspath(f_input))
139    else:
140        return torch._C._get_mobile_model_contained_types_from_buffer(f_input.read())
141
142
143def _backport_for_mobile(f_input, f_output, to_version):
144    r"""Take a input string containing a file name (file-like object) and a new destination to return a boolean.
145
146    Args:
147        f_input: a file-like object (has to implement read, readline, tell, and seek),
148            or a string containing a file name
149        f_output: path to new model destination
150        to_version: the expected output model bytecode version
151    Returns:
152        success: A boolean. If backport success, return true, otherwise false
153    """
154    if isinstance(f_input, (str, os.PathLike)):
155        if not os.path.exists(f_input):
156            raise ValueError(f"The provided filename {f_input} does not exist")
157        if os.path.isdir(f_input):
158            raise ValueError(f"The provided filename {f_input} is a directory")
159
160    if (isinstance(f_input, (str, os.PathLike))) and (
161        isinstance(f_output, (str, os.PathLike))
162    ):
163        return torch._C._backport_for_mobile(
164            os.fspath(f_input), os.fspath(f_output), to_version
165        )
166    else:
167        return torch._C._backport_for_mobile_from_buffer(
168            f_input.read(), str(f_output), to_version
169        )
170
171
172def _backport_for_mobile_to_buffer(f_input, to_version):
173    r"""Take a string containing a file name (file-like object).
174
175    Args:
176        f_input: a file-like object (has to implement read, readline, tell, and seek),
177            or a string containing a file name
178
179    """
180    if isinstance(f_input, (str, os.PathLike)):
181        if not os.path.exists(f_input):
182            raise ValueError(f"The provided filename {f_input} does not exist")
183        if os.path.isdir(f_input):
184            raise ValueError(f"The provided filename {f_input} is a directory")
185
186    if isinstance(f_input, (str, os.PathLike)):
187        return torch._C._backport_for_mobile_to_buffer(os.fspath(f_input), to_version)
188    else:
189        return torch._C._backport_for_mobile_from_buffer_to_buffer(
190            f_input.read(), to_version
191        )
192
193
194def _get_model_ops_and_info(f_input):
195    r"""Retrieve the root (top level) operators of a model and their corresponding compatibility info.
196
197    These root operators can call other operators within them (traced ops), and
198    a root op can call many different traced ops depending on internal code paths in the root op.
199    These traced ops are not returned by this function. Those operators are abstracted into the
200    runtime as an implementation detail (and the traced ops themselves can also call other operators)
201    making retrieving them difficult and their value from this api negligible since they will differ
202    between which runtime version the model is run on. Because of this, there is a false positive this
203    api can't prevent in a compatibility usecase. All the root ops of a model are present in a
204    target runtime, but not all the traced ops are which prevents a model from being able to run.
205    Args:
206        f_input: a file-like object (has to implement read, readline, tell, and seek),
207            or a string containing a file name
208
209    Returns:
210        Operators and info: A Dictionary mapping strings (the qualified names of the root operators)
211        of the model to their OperatorInfo structs.
212
213    Example:
214
215    .. testcode::
216
217        from torch.jit.mobile import _get_model_ops_and_info
218
219        # Get bytecode version from a saved file path
220        ops_and_info = _get_model_ops_and_info("path/to/model.ptl")
221
222    """
223    if isinstance(f_input, (str, os.PathLike)):
224        if not os.path.exists(f_input):
225            raise ValueError(f"The provided filename {f_input} does not exist")
226        if os.path.isdir(f_input):
227            raise ValueError(f"The provided filename {f_input} is a directory")
228
229    if isinstance(f_input, (str, os.PathLike)):
230        return torch._C._get_model_ops_and_info(os.fspath(f_input))
231    else:
232        return torch._C._get_model_ops_and_info(f_input.read())
233