xref: /aosp_15_r20/external/pytorch/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/dlrm_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import os
3import zipfile
4
5import numpy as np  # type: ignore[import]
6from dlrm_data_pytorch import (  # type: ignore[import]
7    collate_wrapper_criteo_offset,
8    CriteoDataset,
9)
10from dlrm_s_pytorch import DLRM_Net  # type: ignore[import]
11
12import torch
13
14
class SparseDLRM(DLRM_Net):
    """DLRM_Net variant that feeds a sparse tensor into the first top-layer module.

    The result of ```interact_features()``` is converted to a sparse COO tensor and
    multiplied, via torch.mm(), with the transposed weight of the first module in
    ``top_l``. The remaining ``top_l`` modules are then applied in order.
    """

    def __init__(self, **args):
        # Every keyword argument is forwarded to DLRM_Net unchanged.
        super().__init__(**args)

    def forward(self, dense_x, lS_o, lS_i):
        """Run the forward pass with a sparse matmul for the first top-layer module.

        Args:
            dense_x: input passed to the bottom MLP (``bot_l``)
            lS_o: first argument forwarded to ``apply_emb`` (presumably the
                embedding-bag offsets; confirm against DLRM_Net)
            lS_i: second argument forwarded to ``apply_emb`` (presumably the
                embedding-bag indices; confirm against DLRM_Net)

        Returns:
            Output of the last module in ``top_l``.
        """
        # Dense features go through the bottom MLP.
        dense_out = self.apply_mlp(dense_x, self.bot_l)
        # Embedding lookup for the sparse features.
        emb_out = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l)
        interacted = self.interact_features(dense_out, emb_out)

        # Hand-written equivalent of the first top-layer module: sparse @ W^T + b.
        sparse_feats = interacted.to_sparse_coo()
        first_layer = self.top_l[0]
        out = torch.mm(sparse_feats, first_layer.weight.T).add(first_layer.bias)

        # The rest of the top layer runs unchanged on the dense result.
        for layer in self.top_l[1:]:
            out = layer(out)

        return out
36
37
def get_valid_name(name):
    """Return ``name`` with every '.' swapped for '_'.

    Names containing '.' are invalid in the data sparsifier, so they are
    rewritten before being used there.
    """
    # Splitting on '.' and re-joining with '_' replaces each dot exactly once.
    pieces = name.split(".")
    return "_".join(pieces)
41
42
def get_dlrm_model(sparse_dlrm=False):
    """Build a DLRM model with the fixed benchmarking configuration.

    The configuration follows the script in bench/dlrm_s_criteo_kaggle.sh and is
    the same one used to train the model benchmarked with the data sparsifier.

    Args:
        sparse_dlrm (bool)
            when truthy a SparseDLRM is constructed, otherwise a plain DLRM_Net

    Returns:
        The newly constructed model instance.
    """
    # Values passed as "ln_emb" (presumably one size per embedding table; confirm
    # against DLRM_Net).
    embedding_sizes = [
        1460,
        583,
        10131227,
        2202608,
        305,
        24,
        12517,
        633,
        3,
        93145,
        5683,
        8351593,
        3194,
        27,
        14992,
        5461306,
        10,
        5652,
        2173,
        4,
        7046547,
        18,
        15,
        286181,
        105,
        142572,
    ]
    config = {
        "m_spa": 16,
        "ln_emb": np.array(embedding_sizes, dtype=np.int32),
        "ln_bot": np.array([13, 512, 256, 64, 16]),
        "ln_top": np.array([367, 512, 256, 1]),
        "arch_interaction_op": "dot",
        "arch_interaction_itself": False,
        "sigmoid_bot": -1,
        "sigmoid_top": 2,
        "sync_dense_params": True,
        "loss_threshold": 0.0,
        "ndevices": 1,
        "qr_flag": False,
        "qr_operation": "mult",
        "qr_collisions": 4,
        "qr_threshold": 200,
        "md_flag": False,
        "md_threshold": 200,
        "weighted_pooling": None,
        "loss_function": "bce",
    }

    # Both classes take the same keyword configuration.
    model_cls = SparseDLRM if sparse_dlrm else DLRM_Net
    return model_cls(**config)
104
105
def dlrm_wrap(X, lS_o, lS_i, device, ndevices=1):
    """Move the model inputs onto ``device`` without running a forward pass.

    Simplified rewrite of ```dlrm_wrap()``` from dlrm_s_pytorch.py.

    Args:
        X: object with a ``.to(device)`` method; always moved
        lS_o: a single such object or a list of them; moved only when ndevices == 1
        lS_i: a single such object or a list of them; moved only when ndevices == 1
        device: target passed to every ``.to()`` call
        ndevices: when different from 1, ``lS_o`` and ``lS_i`` are returned untouched

    Returns:
        Tuple ``(X on device, lS_o, lS_i)``.
    """

    def _move(item):
        # A list is moved element-wise into a new list; anything else is moved directly.
        if isinstance(item, list):
            return [entry.to(device) for entry in item]
        return item.to(device)

    if ndevices == 1:
        lS_i = _move(lS_i)
        lS_o = _move(lS_o)

    return X.to(device), lS_o, lS_i
122
123
def make_test_data_loader(raw_data_file_path, processed_data_file):
    """Create the test-split data loader for the ***kaggle criteo dataset***.

    Simplified rewrite of ```make_criteo_and_loaders()``` from dlrm_data_pytorch.py
    that builds only the test dataset and its loader.

    Args:
        raw_data_file_path: forwarded to CriteoDataset as its 6th positional argument
        processed_data_file: forwarded to CriteoDataset as its 7th positional argument

    Returns:
        A torch.utils.data.DataLoader over the test dataset (batch size 16384,
        no shuffling, 7 workers, last partial batch kept).
    """
    # Positional arguments for CriteoDataset; see dlrm_data_pytorch.py for the
    # meaning of each slot (parameter names are not visible from this file).
    dataset_args = (
        "kaggle",
        -1,
        0.0,
        "total",
        "test",
        raw_data_file_path,
        processed_data_file,
        False,
        False,
    )
    dataset = CriteoDataset(*dataset_args)

    loader_options = {
        "batch_size": 16384,
        "shuffle": False,
        "num_workers": 7,
        "collate_fn": collate_wrapper_criteo_offset,
        "pin_memory": False,
        "drop_last": False,
    }
    return torch.utils.data.DataLoader(dataset, **loader_options)
150
151
def fetch_model(model_path, device, sparse_dlrm=False):
    """This function unzips the zipped model checkpoint (if zipped) and returns a
    model object in eval mode

    A path is treated as a zipped checkpoint only when it ends with ".zip" AND is a
    valid zip archive. Checking ``zipfile.is_zipfile()`` alone is not enough:
    checkpoints written by ``torch.save`` in its default zip-based format are
    themselves zip archives, so a raw checkpoint would be "extracted" and then
    deleted by the cleanup step.

    Args:
        model_path (str)
            path pointing to the zipped/raw model checkpoint file that was dumped in evaluate disk savings.
            For a zipped checkpoint ``<name>.zip`` the archive is expected to contain ``<name>.ckpt``
        device (torch.device)
            device to which model needs to be loaded to
        sparse_dlrm (bool)
            forwarded to ``get_dlrm_model`` to select the model variant

    Returns:
        The model with the checkpoint's state dict loaded, moved to ``device`` and
        switched to eval mode.
    """
    zip_suffix = ".zip"
    is_zipped = model_path.endswith(zip_suffix) and zipfile.is_zipfile(model_path)

    if is_zipped:
        with zipfile.ZipFile(model_path, "r") as zip_ref:
            zip_ref.extractall(os.path.dirname(model_path))
        # Swap only the trailing extension; a ".zip" elsewhere in the path
        # (e.g. in a directory name) must be left alone.
        unzip_path = model_path[: -len(zip_suffix)] + ".ckpt"
    else:
        unzip_path = model_path

    try:
        model = get_dlrm_model(sparse_dlrm=sparse_dlrm)
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        state_dict = torch.load(unzip_path, map_location=device)
        model.load_state_dict(state_dict)
    finally:
        # If there was a zip file, clean up the unzipped file even when loading
        # failed. The original (possibly raw) checkpoint is never removed.
        if is_zipped and os.path.exists(unzip_path):
            os.remove(unzip_path)

    model = model.to(device)
    model.eval()

    return model
179