# mypy: allow-untyped-defs
from __future__ import annotations

import io
import logging
import os
from typing import TYPE_CHECKING

import torch
from torch.onnx import _type_utils as jit_type_utils


if TYPE_CHECKING:
    import onnx

log = logging.getLogger(__name__)


def _create_tensor_proto_with_external_data(
    tensor: torch.Tensor,
    name: str,
    location: str,
    basepath: str,
    dtype_override: onnx.TypeProto | None = None,  # type: ignore[name-defined]
) -> onnx.TensorProto:  # type: ignore[name-defined]
    """Create a TensorProto with external data from a PyTorch tensor.

    The external data is saved to os.path.join(basepath, location).

    Args:
        tensor: Tensor to be saved.
        name: Name of the tensor (i.e., initializer name in the ONNX graph).
        location: Relative location of the external data file
            (e.g., "initializers/weight_0" when the model is "/tmp/model_name.onnx").
        basepath: Base path of the external data file (e.g., "/tmp/external_data" while the model must be in "/tmp").
        dtype_override: ONNX type to store the tensor as. When it differs from
            ``tensor.dtype``, the tensor is cast before serialization.

    Reference for ONNX's external data format:
        How to load?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L187
        How to save?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L43
        How to set ONNX fields?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L88
    """
    # FIXME: Avoid importing onnx into torch.onnx.
    import onnx

    scalar_type = (
        jit_type_utils.JitScalarType.from_onnx_type(
            dtype_override.tensor_type.elem_type
        )
        if dtype_override is not None
        else jit_type_utils.JitScalarType.from_dtype(tensor.dtype)
    )

    # Checkpoints can be stored with a different dtype than the model expects,
    # either because the user script explicitly casts the original type or
    # because PyTorch's type promotion does so implicitly.
    if dtype_override is not None and scalar_type.dtype() != tensor.dtype:
        tensor = tensor.to(scalar_type.dtype())

    tensor_proto = onnx.TensorProto()  # type: ignore[attr-defined]
    tensor_proto.name = name
    tensor_proto.data_type = scalar_type.onnx_type()  # type: ignore[assignment]

    tensor_proto.dims.extend(tensor.shape)
    tensor_proto.data_location = onnx.TensorProto.EXTERNAL  # type: ignore[attr-defined]

    # Settings for saving one tensor per file.
    # Offset is zero because there is no other tensor in the same file.
    key_value_pairs = {
        "location": location,
        "offset": 0,
        "length": tensor.untyped_storage().nbytes(),
    }
    for k, v in key_value_pairs.items():
        entry = tensor_proto.external_data.add()
        entry.key = k
        entry.value = str(v)

    # Actual path to write the tensor's content to.
    external_data_file_path = os.path.join(basepath, location)
    if os.path.exists(external_data_file_path):
        os.remove(external_data_file_path)

    # Create the external data's folder if it does not exist.
    external_data_dir_path = os.path.dirname(external_data_file_path)
    os.makedirs(external_data_dir_path, exist_ok=True)

    # Create a fresh file.
    with open(external_data_file_path, "xb") as data_file:
        # No need to call "seek" because the offset is 0.
        # Write the tensor's content to the file.
        data_file.write(tensor.numpy(force=True).tobytes())

    return tensor_proto
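

# A minimal, illustrative sketch of the inverse operation: reading back a tensor
# written by _create_tensor_proto_with_external_data. The external file holds the
# tensor's raw bytes at offset 0, one tensor per file, so the bytes can be
# reinterpreted directly. This helper is an example only and is not used by the
# exporter; the caller is assumed to supply the dtype and shape recorded in the
# TensorProto's metadata.
def _example_load_external_tensor(
    basepath: str, location: str, dtype: torch.dtype, shape: tuple[int, ...]
) -> torch.Tensor:
    with open(os.path.join(basepath, location), "rb") as data_file:
        # bytearray gives torch.frombuffer a writable buffer, avoiding a warning.
        buffer = bytearray(data_file.read())
    return torch.frombuffer(buffer, dtype=dtype).reshape(shape)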


def _convert_safetensors_to_torch_format(safetensors_file):
    # If this function is called, safetensors is guaranteed to be installed
    # because the HF model with safetensors was already loaded and exported to ONNX.
    from safetensors import safe_open  # type: ignore[import-not-found]

    tensors = {}
    with safe_open(safetensors_file, framework="pt", device="cpu") as f:  # type: ignore[attr-defined]
        for k in f.keys():
            tensors[k] = f.get_tensor(k).cpu()
    return tensors
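

# Illustrative round trip (the file name is a placeholder, and
# `safetensors.torch.save_file` is assumed to be available):
#
#     from safetensors.torch import save_file
#     save_file({"linear.weight": torch.zeros(2, 2)}, "model.safetensors")
#     state_dict = _convert_safetensors_to_torch_format("model.safetensors")
#     # {"linear.weight": tensor([[0., 0.], [0., 0.]])}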


# TODO: generalize to allow more checkpoint formats (torch or gguf)
def save_model_with_external_data(
    basepath: str,
    model_location: str,
    initializer_location: str,
    torch_state_dicts: tuple[dict | str | io.BytesIO, ...],
    onnx_model: onnx.ModelProto,  # type: ignore[name-defined]
    rename_initializer: bool = False,
) -> None:
    """Load PyTorch tensors from files and add them to "onnx_model" as external initializers.

    Output files:
        ONNX model file: os.path.join(basepath, model_location)
        ONNX initializer folder: os.path.join(basepath, initializer_location)

    After running this function, you can run
        ort_sess = onnxruntime.InferenceSession(os.path.join(basepath, model_location))
    to execute the model.

    Args:
        basepath: Base path of the ONNX external data file (e.g., "/path/to/large_model/").
        model_location: Relative location of the ONNX model file.
            E.g., "model.onnx" so that the model file is saved to
            "<basepath>/model.onnx".
        initializer_location: Relative location of the ONNX initializer folder.
            E.g., "initializers" so that the initializers are saved to
            "<basepath>/initializers/".
            Note: When initializers are larger than 2GB, this must be the same as
            `model_location`.
        torch_state_dicts: Dictionaries or files which contain PyTorch tensors to be saved
            as ONNX initializers. For non-dict arguments, `torch.load` is used to load them
            from file-like objects.
        onnx_model: ONNX model to be saved with external initializers.
            If an input name matches a tensor loaded from "torch_state_dicts",
            the tensor is saved as that input's external initializer.
        rename_initializer: Replaces "." with "_" in all ONNX initializer names.
            Not needed by the official torch.onnx.dynamo_export. This is a hack
            for supporting the `FXSymbolicTracer` tracer with fake tensor mode.
            In short, `FXSymbolicTracer` lifts FX parameters (self.linear_weight)
            as inputs (`def forward(self, linear_weight)`) and therefore "." cannot
            appear in initializer names.
    """
    # FIXME: Avoid importing onnx into torch.onnx.
    import onnx

    initializers_to_be_deleted = {}  # Use a dict because it preserves insertion order.
    existing_initializers = {
        k.name: idx for idx, k in enumerate(onnx_model.graph.initializer)
    }
    onnx_input_names = {input.name for input in onnx_model.graph.input}
    for el in torch_state_dicts:
        if isinstance(el, dict):
            # Useful when the state_dict is loaded with torch.load(..., mmap=True, map_location="cpu") by the user.
            # Using torch.save would not leverage mmap, leading to higher memory usage.
            state_dict = el
        elif isinstance(el, str) and el.endswith(".safetensors"):
            state_dict = _convert_safetensors_to_torch_format(el)
        else:
            try:
                # Load the checkpoint with memory-mapping on CPU to support very large models.
                # The underlying torch.UntypedStorage is memory-mapped, so the state_dict is lazily loaded.
                state_dict = torch.load(el, map_location="cpu", mmap=True)
            except (RuntimeError, ValueError) as e:
                if "mmap can only be used with files saved with" in str(
                    e
                ) or isinstance(el, io.BytesIO):
                    log.warning(
                        "Failed to load the checkpoint with memory-map enabled; retrying without memory-map. "
                        "Consider updating the checkpoint with mmap by using torch.save() on PyTorch version >= 1.6."
                    )
                    if isinstance(el, io.BytesIO):
                        el.seek(0)  # torch.load in the `try:` block has already read the buffer.
                    state_dict = torch.load(el, map_location="cpu")
                else:
                    raise

        for name, tensor in state_dict.items():
            if rename_initializer:
                # "transformer.attention.self.query.weight" is mapped to
                # "transformer_attention_self_query_weight" to mimic the
                # name-modifying code in the FX-to-ONNX exporter.
                # See function _replace_get_attr_with_placeholder for details.
                name = name.replace(".", "_")

            # This block tries to match the ONNX initializer name with the torch parameter/buffer.
            # E.g., a PyTorch buffer "transformer.h.0.attn.bias" can be named "h.0.attn.bias" in an ONNX initializer.
            # For each PyTorch tensor name loaded by torch.load:
            #  1. Search for its best match in the ONNX model. E.g., the match of
            #     "transformer_attention_weight" could be "attention_weight".
            #  2. Set "tensor" as the initializer of the matched ONNX input.
            #     E.g., "tensor" is stored as the initializer of "attention_weight".
            # Step 1 is required because tensor names are sometimes stored with a
            # prefix in the dictionary loaded by torch.load.
            if name in onnx_input_names:
                # The same input name shouldn't be matched again.
                onnx_input_names.remove(name)
            else:
                for onnx_input_name in onnx_input_names:
                    if onnx_input_name.endswith(name) or name.endswith(
                        onnx_input_name
                    ):
                        # Found a match. Change "name" to the matched ONNX input name, so
                        # that the initializer is created with the right ONNX name.
                        name = onnx_input_name
                        onnx_input_names.remove(onnx_input_name)
                        break

            relative_tensor_file_path = os.path.join(initializer_location, name)
            # Create one file per tensor.
            # tensor_proto.raw_data is stored to an external file at
            # os.path.join(basepath, relative_tensor_file_path).
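            # Illustrative note on the dtype override below: a checkpoint may store
            # float32 weights while the exported graph declares the matching input
            # as float16 (e.g., after a cast in the user script). Passing the graph
            # input's TypeProto lets _create_tensor_proto_with_external_data cast
            # the tensor before its bytes are serialized.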
            model_input_types = {k.name: k.type for k in onnx_model.graph.input}

            # Mark the old initializer for deletion; a replacement is appended below.
            if name in existing_initializers:
                initializers_to_be_deleted[existing_initializers[name]] = name
            tensor_proto = _create_tensor_proto_with_external_data(
                tensor,
                name,
                relative_tensor_file_path,
                basepath,
                model_input_types.pop(name, None),
            )
            # Add the tensor_proto to the ONNX model as an initializer with external data.
            onnx_model.graph.initializer.append(tensor_proto)
    # Remove old duplicated initializers, if any. Delete in descending index order
    # so that deletions do not invalidate the remaining indices.
    initializers_to_be_deleted = dict(
        sorted(initializers_to_be_deleted.items(), reverse=True)
    )
    for idx in initializers_to_be_deleted:
        del onnx_model.graph.initializer[idx]

    # model_location should be a pure file name such as "file_name.onnx", not "folder/file_name.onnx".
    onnx.save(onnx_model, os.path.join(basepath, model_location))  # type: ignore[attr-defined]
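

# Illustrative end-to-end usage (paths and file names are placeholders, not part
# of this module's API):
#
#     import onnx
#     onnx_model = onnx.load("/tmp/large_model/model.onnx", load_external_data=False)
#     save_model_with_external_data(
#         basepath="/tmp/large_model",
#         model_location="model.onnx",
#         initializer_location="initializers",
#         torch_state_dicts=("/tmp/large_model/checkpoint.pt",),
#         onnx_model=onnx_model,
#     )
#     # The saved model can then be executed with, e.g.,
#     # onnxruntime.InferenceSession("/tmp/large_model/model.onnx").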