# mypy: allow-untyped-defs
import os
import zipfile

import numpy as np  # type: ignore[import]
from dlrm_data_pytorch import (  # type: ignore[import]
    collate_wrapper_criteo_offset,
    CriteoDataset,
)
from dlrm_s_pytorch import DLRM_Net  # type: ignore[import]

import torch


class SparseDLRM(DLRM_Net):
    """Wrapper around ``DLRM_Net`` that feeds the first top layer a sparse tensor.

    The output of ``interact_features()`` is converted to a sparse COO tensor
    and multiplied with the weight matrix of the first top layer through a
    plain ``torch.mm()``; the remaining top layers are applied unchanged.
    """

    def __init__(self, **args):
        super().__init__(**args)

    def forward(self, dense_x, lS_o, lS_i):
        """Run the forward pass, using a sparse matmul for the first top layer.

        The arguments are forwarded to ``apply_mlp`` / ``apply_emb`` inherited
        from ``DLRM_Net``: ``dense_x`` goes through the bottom MLP, while
        ``lS_o`` and ``lS_i`` are handed to the embedding lookup.

        Returns:
            The output of the last layer in ``self.top_l``.
        """
        x = self.apply_mlp(dense_x, self.bot_l)  # dense features
        ly = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l)  # apply embedding bag
        z = self.interact_features(x, ly)

        # Apply the first top layer by hand (z @ W^T + b) so that the
        # multiplication runs on the sparse representation of z.
        first_layer = self.top_l[0]
        z = z.to_sparse_coo()
        z = torch.mm(z, first_layer.weight.T).add(first_layer.bias)
        for layer in self.top_l[1:]:
            z = layer(z)

        return z


def get_valid_name(name):
    """Replaces '.' with '_' as names with '.' are invalid in data sparsifier"""
    return name.replace(".", "_")


def get_dlrm_model(sparse_dlrm=False):
    """Obtain dlrm model. The configs specified are based on the script in
    bench/dlrm_s_criteo_kaggle.sh. The same config is used to train the model
    for benchmarking on data sparsifier.

    Args:
        sparse_dlrm (bool)
            when True a ``SparseDLRM`` is built, otherwise a plain ``DLRM_Net``

    Returns:
        A freshly constructed (untrained) model instance.
    """
    # The configuration is rebuilt on every call so that each model receives
    # its own numpy arrays.
    dlrm_model_config = {
        "m_spa": 16,
        "ln_emb": np.array(
            [
                1460,
                583,
                10131227,
                2202608,
                305,
                24,
                12517,
                633,
                3,
                93145,
                5683,
                8351593,
                3194,
                27,
                14992,
                5461306,
                10,
                5652,
                2173,
                4,
                7046547,
                18,
                15,
                286181,
                105,
                142572,
            ],
            dtype=np.int32,
        ),
        "ln_bot": np.array([13, 512, 256, 64, 16]),
        "ln_top": np.array([367, 512, 256, 1]),
        "arch_interaction_op": "dot",
        "arch_interaction_itself": False,
        "sigmoid_bot": -1,
        "sigmoid_top": 2,
        "sync_dense_params": True,
        "loss_threshold": 0.0,
        "ndevices": 1,
        "qr_flag": False,
        "qr_operation": "mult",
        "qr_collisions": 4,
        "qr_threshold": 200,
        "md_flag": False,
        "md_threshold": 200,
        "weighted_pooling": None,
        "loss_function": "bce",
    }
    model_cls = SparseDLRM if sparse_dlrm else DLRM_Net
    return model_cls(**dlrm_model_config)


def _move_to_device(tensors, device):
    """Move a tensor, or every tensor of a list, to ``device``."""
    if isinstance(tensors, list):
        return [t.to(device) for t in tensors]
    return tensors.to(device)


def dlrm_wrap(X, lS_o, lS_i, device, ndevices=1):
    """Rewritten simpler version of ``dlrm_wrap()`` found in dlrm_s_pytorch.py.
    This function simply moves the input tensors into the device and without the forward pass

    ``lS_o`` and ``lS_i`` (each a tensor or a list of tensors) are moved only
    when ``ndevices == 1``; ``X`` is always moved.

    Returns:
        The tuple ``(X, lS_o, lS_i)`` -- note the order of the last two.
    """
    if ndevices == 1:
        lS_i = _move_to_device(lS_i, device)
        lS_o = _move_to_device(lS_o, device)
    return X.to(device), lS_o, lS_i


def make_test_data_loader(
    raw_data_file_path, processed_data_file, *, batch_size=16384, num_workers=7
):
    """Function to create dataset and dataloaders for the test dataset.
    Rewritten simpler version of ``make_criteo_and_loaders()`` from the dlrm_data_pytorch.py
    that makes the test dataset and dataloaders only for the ***kaggle criteo dataset***

    Args:
        raw_data_file_path
            forwarded to ``CriteoDataset`` (path of the raw dataset file)
        processed_data_file
            forwarded to ``CriteoDataset`` (path of the processed dataset file)
        batch_size (int)
            batch size of the returned loader; the default is the value that
            used to be hard-coded
        num_workers (int)
            number of loader worker processes; the default is the value that
            used to be hard-coded

    Returns:
        A non-shuffling ``torch.utils.data.DataLoader`` over the test split.
    """
    # The arguments are positional and must stay in this order.
    # NOTE(review): confirm the meaning of each slot against the signature of
    # CriteoDataset in dlrm_data_pytorch.py.
    test_data = CriteoDataset(
        "kaggle",
        -1,
        0.0,
        "total",
        "test",
        raw_data_file_path,
        processed_data_file,
        False,
        False,
    )
    return torch.utils.data.DataLoader(
        test_data,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_wrapper_criteo_offset,
        pin_memory=False,
        drop_last=False,
    )


def fetch_model(model_path, device, sparse_dlrm=False):
    """This function unzips the zipped model checkpoint (if zipped) and returns a
    model object

    A path is treated as a zipped checkpoint only when it ends in ``.zip`` and
    is a valid zip archive. The archive is expected to contain
    ``<name>.ckpt`` for an archive called ``<name>.zip``; that file is
    extracted next to the archive, loaded and removed again. Any other path is
    loaded directly and is never deleted.

    Args:
        model_path (str)
            path pointing to the zipped/raw model checkpoint file that was dumped in evaluate disk savings
        device (torch.device)
            device to which model needs to be loaded to
        sparse_dlrm (bool)
            when True the checkpoint is loaded into a ``SparseDLRM``

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    model_path = os.fspath(model_path)
    zip_suffix = ".zip"

    # zipfile.is_zipfile() alone is not enough: checkpoints written by
    # torch.save() use a zip-based container themselves, so a raw checkpoint
    # would be "unzipped" and afterwards deleted by the cleanup step.
    is_zipped = model_path.endswith(zip_suffix) and zipfile.is_zipfile(model_path)

    if is_zipped:
        with zipfile.ZipFile(model_path, "r") as zip_ref:
            zip_ref.extractall(os.path.dirname(model_path))
        # Swap only the trailing suffix; a ".zip" elsewhere in the path (for
        # example in a directory name) must be left alone.
        ckpt_path = model_path[: -len(zip_suffix)] + ".ckpt"
    else:
        ckpt_path = model_path

    try:
        model = get_dlrm_model(sparse_dlrm=sparse_dlrm)
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
    finally:
        # If there was a zip file, clean up the unzipped file -- also when
        # loading failed, and never the file the caller passed in.
        if is_zipped and os.path.exists(ckpt_path):
            os.remove(ckpt_path)

    model = model.to(device)
    model.eval()

    return model