1# mypy: allow-untyped-defs
2import argparse
3import copy
4import os
5import time
6import zipfile
7from typing import Dict, List
8from zipfile import ZipFile
9
10import pandas as pd  # type: ignore[import]
11from dlrm_utils import get_dlrm_model, get_valid_name  # type: ignore[import]
12
13import torch
14from torch.ao.pruning._experimental.data_sparsifier import DataNormSparsifier
15
16
def create_attach_sparsifier(model, **sparse_config):
    """Build a DataNormSparsifier and register the model's embedding parameters with it.

    Args:
        model (nn.Module)
            model whose parameters with "emb_l" in their name are attached to the sparsifier
        sparse_config (Dict)
            Keyword config forwarded to the DataNormSparsifier. Should contain the following keys:
                - sparse_block_shape
                - norm
                - sparsity_level

    Returns:
        The DataNormSparsifier with every matching parameter added as data.
    """
    sparsifier = DataNormSparsifier(**sparse_config)

    # Lazily pick out the embedding parameters; everything else is left untouched.
    embedding_params = (
        (param_name, param)
        for param_name, param in model.named_parameters()
        if "emb_l" in param_name
    )
    for param_name, param in embedding_params:
        # The raw parameter name is passed through get_valid_name before being
        # used as the key under which the data is registered.
        sparsifier.add_data(name=get_valid_name(param_name), data=param)

    return sparsifier
35
36
def save_model_states(
    state_dict,
    sparsified_model_dump_path,
    save_file_name,
    sparse_block_shape,
    norm,
    zip=True,
):
    """Dumps the state_dict() of the model, optionally zip-compressed.

    The file is written to
    ``<sparsified_model_dump_path>/<norm>/config_<sparse_block_shape>/<save_file_name>``;
    missing directories are created.

    Args:
        state_dict (Dict)
            The state_dict() as dumped by dlrm_s_pytorch.py. Only the model state will be extracted
            from this dictionary. This corresponds to the 'state_dict' key in the state_dict dictionary.
            >>> model_state = state_dict['state_dict']
        sparsified_model_dump_path (str)
            Root directory under which the model state is saved.
        save_file_name (str)
            The filename (not path) when saving the model state dictionary
        sparse_block_shape (Tuple)
            The block shape corresponding to the data norm sparsifier. **Used for creating save directory**
        norm (str)
            type of norm (L1, L2) for the datanorm sparsifier. **Used for creating save directory**
        zip (bool)
            if True, the file is zip-compressed and the uncompressed file is removed.

    Returns:
        Tuple of (absolute path of the saved file, file size in whole MiB rounded down).
        Files smaller than 1 MiB therefore report a size of 0.
    """
    model_state = state_dict["state_dict"]
    model_state_path = os.path.join(
        sparsified_model_dump_path,
        str(norm),
        f"config_{sparse_block_shape}",
        save_file_name,
    )

    os.makedirs(os.path.dirname(model_state_path), exist_ok=True)
    torch.save(model_state, model_state_path)

    if zip:
        # Only swap a trailing ".ckpt" extension for ".zip"; for any other name
        # append ".zip". A plain str.replace() would leave the path unchanged
        # when ".ckpt" is absent (the archive would then overwrite, and be
        # deleted together with, the checkpoint) and would also rewrite
        # ".ckpt" occurring inside directory names.
        root, ext = os.path.splitext(model_state_path)
        zip_path = f"{root}.zip" if ext == ".ckpt" else f"{model_state_path}.zip"

        # Use a distinct name for the archive handle instead of rebinding the
        # `zip` parameter.
        with ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as archive:
            archive.write(model_state_path, save_file_name)
        os.remove(model_state_path)  # store it as zip, remove uncompressed
        model_state_path = zip_path

    model_state_path = os.path.abspath(model_state_path)
    file_size = os.path.getsize(model_state_path) >> 20  # bytes -> whole MiB
    return model_state_path, file_size
82
83
def sparsify_model(path_to_model, sparsified_model_dump_path):
    """Sparsifies the embedding layers of the dlrm model for different sparsity levels, norms and block shapes
    using the DataNormSparsifier.
    The function tracks the step time of the sparsifier and the size of the compressed checkpoint and collates
    it into a csv.

    The grid of configurations is fixed inside the function:
        - sparsity levels: 0.0, 0.1, ..., 0.9, then 0.91, 0.92, ..., 0.99, 1.0
        - norms: "L1", "L2"
        - sparse block shapes: (1, 1), (1, 4)
    The combination norm="L2" with block shape (1, 1) is skipped.

    Note::
        This function dumps a csv sparse_model_metadata.csv in the current directory.

    Args:
        path_to_model (str)
            path to the trained criteo model ckpt file. The loaded checkpoint must contain
            a 'state_dict' key holding the model state.
        sparsified_model_dump_path (str)
            root directory under which the sparsified (zip-compressed) model states are saved
    """
    sparsity_levels = [sl / 10 for sl in range(0, 10)]
    sparsity_levels += [0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98, 0.99, 1.0]

    norms = ["L1", "L2"]
    sparse_block_shapes = [(1, 1), (1, 4)]

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    print("Running for sparsity levels - ", sparsity_levels)
    print("Running for sparse block shapes - ", sparse_block_shapes)
    print("Running for norms - ", norms)

    orig_model = get_dlrm_model()
    saved_state = torch.load(path_to_model, map_location=device)
    orig_model.load_state_dict(saved_state["state_dict"])

    orig_model = orig_model.to(device)

    # One row per (norm, block shape, sparsity level) run; becomes the csv.
    stat_dict: Dict[str, List] = {
        "norm": [],
        "sparse_block_shape": [],
        "sparsity_level": [],
        "step_time_sec": [],
        "zip_file_size": [],
        "path": [],
    }
    for norm in norms:
        for sbs in sparse_block_shapes:
            # NOTE(review): presumably skipped because with 1x1 blocks the L1 and
            # L2 norms rank elements identically - confirm.
            if norm == "L2" and sbs == (1, 1):
                continue
            for sl in sparsity_levels:
                # Every configuration starts from a fresh copy of the dense model.
                model = copy.deepcopy(orig_model)
                sparsifier = create_attach_sparsifier(
                    model, sparse_block_shape=sbs, norm=norm, sparsity_level=sl
                )

                # CUDA kernels are launched asynchronously, so synchronize before
                # reading the clock on both sides of the step; perf_counter is a
                # monotonic, high-resolution clock meant for measuring intervals.
                if device.type == "cuda":
                    torch.cuda.synchronize()
                t1 = time.perf_counter()
                sparsifier.step()
                if device.type == "cuda":
                    torch.cuda.synchronize()
                t2 = time.perf_counter()

                step_time = t2 - t1
                norm_sl = f"{norm}_{sbs}_{sl}"
                print(f"Step Time for {norm_sl}=: {step_time} s")

                sparsifier.squash_mask()

                # Reuse the loaded checkpoint dict, swapping in the sparsified weights.
                saved_state["state_dict"] = model.state_dict()
                file_name = f"criteo_model_norm={norm}_sl={sl}.ckpt"
                state_path, file_size = save_model_states(
                    saved_state, sparsified_model_dump_path, file_name, sbs, norm=norm
                )

                stat_dict["norm"].append(norm)
                stat_dict["sparse_block_shape"].append(sbs)
                stat_dict["sparsity_level"].append(sl)
                stat_dict["step_time_sec"].append(step_time)
                stat_dict["zip_file_size"].append(file_size)
                stat_dict["path"].append(state_path)

    df = pd.DataFrame(stat_dict)
    filename = "sparse_model_metadata.csv"
    df.to_csv(filename, index=False)

    print(f"Saved sparsified metadata file in {filename}")
170
171
if __name__ == "__main__":
    # Command-line entry point: sparsify a trained model checkpoint and dump
    # the resulting model states under the given directory.
    parser = argparse.ArgumentParser(
        description="Sparsify the embedding layers of a trained DLRM checkpoint "
        "with the DataNormSparsifier."
    )
    # Both paths are mandatory: without them None would be passed on and fail
    # later with an unrelated-looking error instead of a usage message.
    parser.add_argument(
        "--model-path",
        "--model_path",
        type=str,
        required=True,
        help="path to the trained criteo model ckpt file",
    )
    parser.add_argument(
        "--sparsified-model-dump-path",
        "--sparsified_model_dump_path",
        type=str,
        required=True,
        help="directory under which the sparsified model states are saved",
    )
    args = parser.parse_args()

    sparsify_model(args.model_path, args.sparsified_model_dump_path)
181