# mypy: allow-untyped-defs
from __future__ import annotations

import io
import logging
import os
from typing import TYPE_CHECKING

import torch
from torch.onnx import _type_utils as jit_type_utils


if TYPE_CHECKING:
    import onnx

log = logging.getLogger(__name__)


def _create_tensor_proto_with_external_data(
    tensor: torch.Tensor,
    name: str,
    location: str,
    basepath: str,
    dtype_override: onnx.TypeProto | None = None,  # type: ignore[name-defined]
) -> onnx.TensorProto:  # type: ignore[name-defined]
26    """Create a TensorProto with external data from a PyTorch tensor.
27    The external data is saved to os.path.join(basepath, location).
28
29    Args:
30        tensor: Tensor to be saved.
31        name: Name of the tensor (i.e., initializer name in ONNX graph).
        location: Relative location of the external data file
            (e.g., "initializers/weight_0" when the model file is "/tmp/model_name.onnx"
            and basepath is "/tmp").
        basepath: Base path of the external data file, under which `location` is
            resolved (e.g., "/tmp", so the data is written to "/tmp/initializers/weight_0").
        dtype_override: Optional ONNX TypeProto; when set, the tensor is cast to
            its element type before serialization.

    Reference for ONNX's external data format:
        How to load?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L187
        How to save?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L43
        How to set ONNX fields?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L88
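
    Example (a minimal sketch; the tensor, name, and paths below are
    illustrative assumptions, and the call writes a real file under basepath,
    so it is skipped here):

        >>> import torch
        >>> proto = _create_tensor_proto_with_external_data(  # doctest: +SKIP
        ...     tensor=torch.zeros(2, 3),
        ...     name="weight",
        ...     location="initializers/weight_0",
        ...     basepath="/tmp",
        ... )
        >>> proto.name  # doctest: +SKIP
        'weight'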
44    """
45    # FIXME: Avoid importing onnx into torch.onnx.
46    import onnx
47
48    scalar_type = (
49        jit_type_utils.JitScalarType.from_onnx_type(
50            dtype_override.tensor_type.elem_type
51        )
52        if dtype_override is not None
53        else jit_type_utils.JitScalarType.from_dtype(tensor.dtype)
54    )
55
    # Checkpoints can be stored with a different dtype than the model expects,
    # either because the user script explicitly cast the original type or
    # because PyTorch's type promotion did so implicitly.
    if dtype_override is not None and scalar_type.dtype() != tensor.dtype:
        tensor = tensor.to(scalar_type.dtype())

    tensor_proto = onnx.TensorProto()  # type: ignore[attr-defined]
    tensor_proto.name = name
    tensor_proto.data_type = scalar_type.onnx_type()  # type: ignore[assignment]

    tensor_proto.dims.extend(tensor.shape)
    tensor_proto.data_location = onnx.TensorProto.EXTERNAL  # type: ignore[attr-defined]

    # Settings for saving one tensor per file.
    # Offset is zero because there is no other tensor in the same file.
    key_value_pairs = {
        "location": location,
        "offset": 0,
        "length": tensor.untyped_storage().nbytes(),
    }
    for k, v in key_value_pairs.items():
        entry = tensor_proto.external_data.add()
        entry.key = k
        entry.value = str(v)

    # Actual path to write content of tensor.
    external_data_file_path = os.path.join(basepath, location)
    if os.path.exists(external_data_file_path):
        os.remove(external_data_file_path)

    # Create the external data directory if it does not exist.
    external_data_dir_path = os.path.dirname(external_data_file_path)
    if not os.path.exists(external_data_dir_path):
        os.makedirs(external_data_dir_path)

    # Create a fresh file.
    with open(external_data_file_path, "xb") as data_file:
        # No need to call "seek" because offset is 0.
        # data_file.seek(0)
        # Write tensor content to the file.
        data_file.write(tensor.numpy(force=True).tobytes())

    return tensor_proto


def _convert_safetensors_to_torch_format(safetensors_file):
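    """Load every tensor from a .safetensors file into a dict of CPU torch.Tensors.

    Example (a hypothetical sketch; "model.safetensors" is an illustrative path,
    not a file shipped with this module):

        >>> tensors = _convert_safetensors_to_torch_format("model.safetensors")  # doctest: +SKIP
        >>> all(t.device.type == "cpu" for t in tensors.values())  # doctest: +SKIP
        True
    """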
    # If this function is called, the safetensors package is guaranteed to be
    # installed, because the HF model with safetensors was already loaded and
    # exported to ONNX.
    from safetensors import safe_open  # type: ignore[import-not-found]

    tensors = {}
    with safe_open(safetensors_file, framework="pt", device="cpu") as f:  # type: ignore[attr-defined]
        for k in f.keys():
            tensors[k] = f.get_tensor(k).cpu()
    return tensors


# TODO: generalize to allow more checkpoint formats (torch or gguf)
def save_model_with_external_data(
    basepath: str,
    model_location: str,
    initializer_location: str,
    torch_state_dicts: tuple[dict | str | io.BytesIO, ...],
    onnx_model: onnx.ModelProto,  # type: ignore[name-defined]
    rename_initializer: bool = False,
) -> None:
    """Load PyTorch tensors from files and add them to "onnx_model" as external initializers.

    Output files:
        ONNX model file path: os.path.join(basepath, model_location)
        ONNX initializer folder: os.path.join(basepath, initializer_location)

    After running this function, you can do
        ort_sess = onnxruntime.InferenceSession(os.path.join(basepath, model_location))
    to execute the model.

    Args:
        basepath: Base path of the ONNX external data file (e.g., "/path/to/large_model/").
        model_location: Relative location of the ONNX model file.
            E.g., "model.onnx" so that the model file is saved to
            "<basepath>/model.onnx".
        initializer_location: Relative location of the ONNX initializer folder.
            E.g., "initializers" so that the initializers are saved to
            "<basepath>/initializers/".
            Note: When initializers are >2GB, this must be the same as `model_location`.
        torch_state_dicts: Dictionaries or files which contain PyTorch tensors to be saved
            as ONNX initializers. For non-dict arguments, `torch.load` is used to load
            them from file-like objects.
        onnx_model: ONNX model to be saved with external initializers.
            If an input name matches a tensor loaded from "torch_state_dicts",
            the tensor will be saved as that input's external initializer.
        rename_initializer: Replaces "." with "_" in all ONNX initializer names.
            Not needed by the official torch.onnx.dynamo_export. This is a hack
            for supporting the `FXSymbolicTracer` tracer with fake tensor mode.
            In short, `FXSymbolicTracer` lifts FX parameters (self.linear_weight)
            as inputs (`def forward(self, linear_weight)`), so "." cannot appear in names.
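
    Example (a hedged sketch; the paths and checkpoint below are illustrative
    assumptions, not values required by this API, and the call writes files,
    so it is skipped here):

        >>> import onnx
        >>> onnx_model = onnx.load(  # doctest: +SKIP
        ...     "/path/to/large_model/model.onnx", load_external_data=False
        ... )
        >>> save_model_with_external_data(  # doctest: +SKIP
        ...     basepath="/path/to/large_model",
        ...     model_location="model.onnx",
        ...     initializer_location="initializers",
        ...     torch_state_dicts=("/path/to/checkpoint.pt",),
        ...     onnx_model=onnx_model,
        ... )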
153    """
154    # FIXME: Avoid importing onnx into torch.onnx.
155    import onnx
156
157    initializers_to_be_deleted = {}  # Using dict because it is **ordered**
158    existing_initializers = {
159        k.name: idx for idx, k in enumerate(onnx_model.graph.initializer)
160    }
161    onnx_input_names = {input.name for input in onnx_model.graph.input}
162    for el in torch_state_dicts:
163        if isinstance(el, dict):
            # Useful when the user already loaded the state_dict with
            # torch.load(..., mmap=True, map_location="cpu"); round-tripping it
            # through torch.save would not leverage mmap, leading to higher memory usage.
            state_dict = el
        else:
            if isinstance(el, str) and el.endswith(".safetensors"):
                state_dict = _convert_safetensors_to_torch_format(el)
            else:
                try:
                    # Load the checkpoint using a memory-map on CPU to support really large models.
                    # The underlying torch.UntypedStorage is memory mapped, so state_dict is lazily loaded.
                    state_dict = torch.load(el, map_location="cpu", mmap=True)
                except (RuntimeError, ValueError) as e:
                    if "mmap can only be used with files saved with" in str(
                        e
                    ) or isinstance(el, io.BytesIO):
                        log.warning(
                            "Failed to load the checkpoint with memory-map enabled; retrying without memory-map. "
                            "Consider re-saving the checkpoint with torch.save() on PyTorch >= 1.6 to make it mmap-compatible."
                        )
                        if isinstance(el, io.BytesIO):
                            el.seek(0)  # torch.load from `try:` has read the file.
                        state_dict = torch.load(el, map_location="cpu")
                    else:
                        raise e

        for name, tensor in state_dict.items():
            if rename_initializer:
                # Basically, "transformer.attention.self.query.weight" is mapped
                # to "transformer_attention_self_query_weight" for mimicking the
                # name-modifying code in FX-to-ONNX exporter.
                # See function _replace_get_attr_with_placeholder for details.
                name = name.replace(".", "_")

            # This block tries to match the ONNX initializer name with a torch parameter/buffer,
            #  e.g., a PyTorch buffer "transformer.h.0.attn.bias" can be named "h.0.attn.bias" in an ONNX initializer.
            # For each PyTorch tensor name loaded by torch.load,
            #  1.  Search for its best match in the ONNX model. E.g., the match of
            #       "transformer_attention_weight" could be "attention_weight".
            #  2.  Set "tensor" as the initializer of the matched ONNX input.
            #      E.g., "tensor" is stored as the initializer of "attention_weight".
            # Step 1 is required because tensor names are sometimes stored with a prefix
            #  in the dictionary loaded by torch.load.
            if name in onnx_input_names:
                # Same input name shouldn't be matched again
                onnx_input_names.remove(name)
            else:
                for onnx_input_name in onnx_input_names:
                    if onnx_input_name.endswith(name) or name.endswith(onnx_input_name):
                        # Found a match. Change name to the matched ONNX input name, so that we
                        # create the initializer with the right ONNX name.
                        name = onnx_input_name
                        onnx_input_names.remove(onnx_input_name)
                        break

            relative_tensor_file_path = os.path.join(initializer_location, name)
            # Create one file per tensor.
            # tensor_proto.raw_data is stored to external file at
            # os.path.join(basepath, relative_tensor_file_path).
            model_input_types = {k.name: k.type for k in onnx_model.graph.input}

            # Mark for deletion - a replacement will be appended next
            if name in existing_initializers:
                initializers_to_be_deleted[existing_initializers[name]] = name
            tensor_proto = _create_tensor_proto_with_external_data(
                tensor,
                name,
                relative_tensor_file_path,
                basepath,
                model_input_types.pop(name, None),
            )
            # Add the tensor_proto to the ONNX model as an initializer with external data.
            onnx_model.graph.initializer.append(tensor_proto)
    # Remove old duplicated initializers, if any. Delete in descending index order
    # so that earlier deletions do not invalidate the remaining indices.
    initializers_to_be_deleted = dict(
        sorted(initializers_to_be_deleted.items(), reverse=True)
    )
    for idx in initializers_to_be_deleted:
        del onnx_model.graph.initializer[idx]

    # model_location should be a pure file name such as "file_name.onnx", not "folder/file_name.onnx".
    onnx.save(onnx_model, os.path.join(basepath, model_location))  # type: ignore[attr-defined]
245