# mypy: allow-untyped-defs
import argparse
import copy
import os
import time
import zipfile
from typing import Dict, List
from zipfile import ZipFile

import pandas as pd  # type: ignore[import]
from dlrm_utils import get_dlrm_model, get_valid_name  # type: ignore[import]

import torch
from torch.ao.pruning._experimental.data_sparsifier import DataNormSparsifier


def create_attach_sparsifier(model, **sparse_config):
    """Create a DataNormSparsifier and attach it to the model embedding layers.

    Args:
        model (nn.Module)
            model whose embedding layers ('emb_l') are attached to the sparsifier
        sparse_config (Dict)
            Config for the DataNormSparsifier. Should contain the following keys:
                - sparse_block_shape
                - norm
                - sparsity_level
    """
    data_norm_sparsifier = DataNormSparsifier(**sparse_config)
    for name, parameter in model.named_parameters():
        if "emb_l" in name:
            valid_name = get_valid_name(name)
            data_norm_sparsifier.add_data(name=valid_name, data=parameter)
    return data_norm_sparsifier


def save_model_states(
    state_dict,
    sparsified_model_dump_path,
    save_file_name,
    sparse_block_shape,
    norm,
    zip=True,
):
    """Dumps the model state contained in a checkpoint dictionary.

    Args:
        state_dict (Dict)
            The state_dict() as dumped by dlrm_s_pytorch.py. Only the model state,
            stored under the 'state_dict' key, is extracted and saved.
            >>> model_state = state_dict['state_dict']
        sparsified_model_dump_path (str)
            Directory under which the model state is saved
        save_file_name (str)
            The filename (not path) used when saving the model state dictionary
        sparse_block_shape (Tuple)
            The block shape corresponding to the data norm sparsifier. **Used for creating the save directory**
        norm (str)
            Type of norm (L1, L2) for the data norm sparsifier. **Used for creating the save directory**
        zip (bool)
            If True, the file is zip-compressed.
    """
    folder_name = os.path.join(sparsified_model_dump_path, str(norm))

    # save model-only states
    folder_str = f"config_{sparse_block_shape}"
    model_state = state_dict["state_dict"]
    model_state_path = os.path.join(folder_name, folder_str, save_file_name)

    os.makedirs(os.path.dirname(model_state_path), exist_ok=True)
    torch.save(model_state, model_state_path)

    if zip:
        zip_path = model_state_path.replace(".ckpt", ".zip")
        with ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zip_file:
            zip_file.write(model_state_path, save_file_name)
        os.remove(model_state_path)  # store it as zip, remove uncompressed
        model_state_path = zip_path

    model_state_path = os.path.abspath(model_state_path)
    file_size = os.path.getsize(model_state_path)
    file_size = file_size >> 20  # size in MiB
    return model_state_path, file_size
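
# A minimal usage sketch of the DataNormSparsifier flow exercised below (the toy
# EmbeddingBag here is hypothetical; the real script attaches to the DLRM 'emb_l'
# parameters via create_attach_sparsifier):
#
#   import torch.nn as nn
#   emb = nn.EmbeddingBag(1000, 16)
#   sparsifier = DataNormSparsifier(
#       sparsity_level=0.8, sparse_block_shape=(1, 4), norm="L1"
#   )
#   sparsifier.add_data(name="emb_weight", data=emb.weight)
#   sparsifier.step()         # compute and apply the sparsification mask
#   sparsifier.squash_mask()  # make the sparsity permanent in the weight data
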

def sparsify_model(path_to_model, sparsified_model_dump_path):
    """Sparsifies the embedding layers of the DLRM model for different sparsity levels,
    norms and block shapes using the DataNormSparsifier.
    The function tracks the step time of the sparsifier and the size of the compressed
    checkpoint, and collates them into a csv.

    Note::
        This function dumps a csv sparse_model_metadata.csv in the current directory.

    Args:
        path_to_model (str)
            path to the trained criteo model ckpt file
        sparsified_model_dump_path (str)
            directory under which the sparsified model checkpoints are dumped
    """
    sparsity_levels = [sl / 10 for sl in range(0, 10)]
    sparsity_levels += [0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98, 0.99, 1.0]

    norms = ["L1", "L2"]
    sparse_block_shapes = [(1, 1), (1, 4)]

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    print("Running for sparsity levels - ", sparsity_levels)
    print("Running for sparse block shapes - ", sparse_block_shapes)
    print("Running for norms - ", norms)

    orig_model = get_dlrm_model()
    saved_state = torch.load(path_to_model, map_location=device)
    orig_model.load_state_dict(saved_state["state_dict"])

    orig_model = orig_model.to(device)
    step_time_dict = {}

    stat_dict: Dict[str, List] = {
        "norm": [],
        "sparse_block_shape": [],
        "sparsity_level": [],
        "step_time_sec": [],
        "zip_file_size": [],
        "path": [],
    }
    for norm in norms:
        for sbs in sparse_block_shapes:
            if norm == "L2" and sbs == (1, 1):
                continue
            for sl in sparsity_levels:
                model = copy.deepcopy(orig_model)
                sparsifier = create_attach_sparsifier(
                    model, sparse_block_shape=sbs, norm=norm, sparsity_level=sl
                )

                t1 = time.time()
                sparsifier.step()
                t2 = time.time()

                step_time = t2 - t1
                norm_sl = f"{norm}_{sbs}_{sl}"
                print(f"Step Time for {norm_sl}: {step_time} s")

                step_time_dict[norm_sl] = step_time

                sparsifier.squash_mask()

                saved_state["state_dict"] = model.state_dict()
                file_name = f"criteo_model_norm={norm}_sl={sl}.ckpt"
                state_path, file_size = save_model_states(
                    saved_state, sparsified_model_dump_path, file_name, sbs, norm=norm
                )

                stat_dict["norm"].append(norm)
                stat_dict["sparse_block_shape"].append(sbs)
                stat_dict["sparsity_level"].append(sl)
                stat_dict["step_time_sec"].append(step_time)
                stat_dict["zip_file_size"].append(file_size)
                stat_dict["path"].append(state_path)

    df = pd.DataFrame(stat_dict)
    filename = "sparse_model_metadata.csv"
    df.to_csv(filename, index=False)

    print(f"Saved sparsified metadata file in {filename}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", "--model_path", type=str)
    parser.add_argument(
        "--sparsified-model-dump-path", "--sparsified_model_dump_path", type=str
    )
    args = parser.parse_args()

    sparsify_model(args.model_path, args.sparsified_model_dump_path)
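
# Reading the results back (a minimal sketch; the csv name and 'path' column come
# from sparsify_model above, and the checkpoints are assumed to have been saved
# with zip=True so each path points at a .zip archive):
#
#   import io
#   import zipfile
#   import pandas as pd
#   import torch
#
#   metadata = pd.read_csv("sparse_model_metadata.csv")
#   zip_path = metadata["path"].iloc[0]
#   with zipfile.ZipFile(zip_path) as zf:
#       ckpt_name = zf.namelist()[0]
#       buffer = io.BytesIO(zf.read(ckpt_name))
#   model_state = torch.load(buffer, map_location="cpu")  # sparsified state_dict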