# mypy: allow-untyped-defs
# This module provides a FAST (on GPU) content addressable store for storages
# (and tensors on top of them) with VERY WEAK portability guarantees (e.g.,
# don't expect CPU/CUDA to address to the same hash, don't expect it to be
# portable across devices) that is NOT cryptographically secure.  In return,
# we are able to hash 40G of tensor data on GPU in less than a second,
# compared to running SHA-1 on CPU, which would take a minute or so.  The
# primary use case is for efficiently snapshotting intermediate tensor data
# for offline debugging, but it's been put in this module in case you think
# of another use case for it.  The hash function could be replaced with a
# straight reimplementation of SHA-1, which would give us much stronger
# portability guarantees.
#
# WARNING: THERE IS NO BC/FC GUARANTEE FOR THIS FORMAT!  If you need to format
# shift the result, consider packing it into a single torch.save object
# with traditional view sharing.
#
# Because of the weak portability guarantees, you can only write to the
# content store from a single process; we don't provide any capability
# of "reopening" a content store to add more things to it.  But we don't
# assume that you can keep all of the tensors you want to add to the store
# in memory at once, because you probably can't!  Nor do we assume that
# you know a priori whether or not two storages can be deduplicated.
#
# Note: only storages are content-addressed; tensors are name-addressed
#
# Note: our padding strategy means that [1, 0] and [1] int16 tensors would
# map to the same (padded) storage.  We think this will be immaterial for
# most users.

import ctypes
import functools
import hashlib
import os.path
import struct
from collections import defaultdict
from typing import Dict, Optional, Set

import torch
import torch._prims as prims
import torch._utils
import torch.nn.functional as F
from torch._C import default_generator
from torch.multiprocessing.reductions import StorageWeakRef


def lazy_compile(**compile_kwargs):
    """Lazily wrap a function with torch.compile on the first call

    This avoids eagerly importing dynamo.
    """

    def decorate_fn(fn):
        @functools.wraps(fn)
        def compile_hook(*args, **kwargs):
            compiled_fn = torch.compile(fn, **compile_kwargs)
            globals()[fn.__name__] = functools.wraps(fn)(compiled_fn)
            return compiled_fn(*args, **kwargs)

        return compile_hook

    return decorate_fn

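# A minimal sketch of the laziness above: the first call pays the
# torch.compile cost and rebinds the decorated name in this module's globals,
# so later calls dispatch directly to the compiled function.  The `double`
# function here is hypothetical and assumed to be defined alongside
# lazy_compile in this module; it is for illustration only:
#
#     @lazy_compile(dynamic=True)
#     def double(x):
#         return x * 2
#
#     double(torch.arange(4))  # first call: compiles, rebinds the global, runs
#     double(torch.arange(8))  # later calls: go straight to the compiled fn
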
# Use of torch.compile is mandatory for (1) good memory usage
# and (2) xor_sum implementation.  This is our first instance of
# using PT2 to implement a kernel in PyTorch; if we get AOT capabilities
# it would be good to apply it here.
@lazy_compile(dynamic=True)
def hash_storage_kernel(x):
    # The randint calls are carefully written to hit things we
    # have lowerings for in inductor.  Lack of unsigned 32-bit integer
    # is a pain.
    a = torch.randint(
        -(2**31), 2**31, x.shape, device=x.device, dtype=torch.int32
    ).abs()
    a = ((a % (2**31 - 1)) + 1).long()
    b = (
        torch.randint(-(2**31), 2**31, x.shape, device=x.device, dtype=torch.int32)
        .abs()
        .long()
    )
    # This is a standard shift-multiply universal hash family
    # plus xor sum hash, using Philox to generate random numbers.
    # Our Philox RNG is not deterministic across devices, so
    # don't use this for stable hashing.
    #
    # This assumes fixed length, so you're also obligated to bucket
    # by the length of the tensor as well.
    return prims.xor_sum((a * x + b).int(), [0])


# Returns a hex digest of the data in the storage.  Guaranteed to be
# SHA-1 if stable_hash=True; otherwise it will be consistent within a single
# process run but not necessarily across processes.
def hash_storage(storage: torch.UntypedStorage, *, stable_hash: bool = False) -> str:
    import torch._dynamo
    from torch._dynamo.utils import is_compile_supported

    device_type = storage.device.type
    if stable_hash or not is_compile_supported(device_type):
        cpu_storage = storage.cpu()
        # TODO: make storage support buffer protocol so this isn't
        # necessary
        buf = (ctypes.c_byte * cpu_storage.nbytes()).from_address(
            cpu_storage.data_ptr()
        )
        sha1 = hashlib.sha1()
        sha1.update(buf)
        return sha1.hexdigest()

    # TODO: factor this into a random utility
    if device_type == "cpu":
        generator = default_generator
    elif device_type == "cuda":
        import torch.cuda

        generator = torch.cuda.default_generators[storage.device.index]
    else:
        raise AssertionError(f"unhandled device type {device_type}")
    state = generator.get_state()
    try:
        generator.manual_seed(0)
        x = torch.empty(0, dtype=torch.uint8, device=storage.device).set_(storage)  # type: ignore[call-overload]
        # The dtype-casting view cannot be compiled, so the
        # padding/reshaping also needs to be done externally, even
        # though it could be profitably fused.
        pad = -x.numel() % 4
        if pad > 0:
            x = F.pad(x, (0, pad), "constant", 0)
        x = x.view(torch.int32)
        # We run the 32-bit hash five times with differing parameters to
        # reduce the chance of collision.
        ITER = 5
        cs = [hash_storage_kernel(x).item() for _ in range(ITER)]
        return struct.pack(">" + "i" * ITER, *cs).hex()
    finally:
        generator.set_state(state)

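# Illustrative usage of hash_storage (a sketch, not a test; no digest values
# are shown because they are process-dependent unless stable_hash=True).
# Hashes are computed over untyped storages, so two tensors viewing the same
# storage produce the same digest, and stable_hash=True matches a plain SHA-1
# over the raw bytes:
#
#     t = torch.arange(16, dtype=torch.uint8)
#     u = t[8:]  # a view sharing t's untyped storage
#     assert hash_storage(t.untyped_storage()) == hash_storage(u.untyped_storage())
#
#     stable = hash_storage(t.untyped_storage(), stable_hash=True)
#     assert stable == hashlib.sha1(bytes(range(16))).hexdigest()
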
class ContentStoreWriter:
    # Structure:
    #   storages/
    #     00/
    #       0000..00
    #   tensors/
    #     name
    def __init__(self, loc: str, stable_hash: bool = False) -> None:
        self.loc: str = loc
        self.seen_storage_hashes: Set[str] = set()
        self.stable_hash = stable_hash

    # TODO: offer some sort of non-blocking API to speed things up
    def write_storage(self, storage: torch.UntypedStorage) -> str:
        h = hash_storage(storage, stable_hash=self.stable_hash)
        if h in self.seen_storage_hashes:
            return h
        # TODO: consider not using torch.save for this; we don't actually
        # need any metadata for the storage
        subfolder = os.path.join(self.loc, "storages")
        os.makedirs(subfolder, exist_ok=True)
        target = os.path.join(subfolder, h)
        if os.path.exists(target):
            return h
        torch.save(storage, target)
        self.seen_storage_hashes.add(h)
        return h

    def compute_tensor_metadata(self, t: torch.Tensor, h=None):
        if h is None:
            h = hash_storage(t.untyped_storage(), stable_hash=self.stable_hash)
        return (
            t.dtype,
            h,
            t.storage_offset(),
            tuple(t.shape),
            t.stride(),
            torch._utils.get_tensor_metadata(t),
        )

    def write_tensor(self, name: str, t: torch.Tensor) -> None:
        storage = t.untyped_storage()
        h = self.write_storage(storage)
        # TODO: Support more advanced snapshotting of requires_grad/grad/etc
        d, f = os.path.split(name)
        payload = self.compute_tensor_metadata(t, h=h)
        subfolder = os.path.join(self.loc, "tensors", d)
        os.makedirs(subfolder, exist_ok=True)
        torch.save(payload, os.path.join(subfolder, f))


class ContentStoreReader:
    def __init__(self, loc: str, *, cache=True) -> None:
        self.loc = loc
        self.storage_cache: Optional[
            Dict[Optional[torch.device], Dict[str, StorageWeakRef]]
        ] = None
        if cache:
            self.storage_cache = defaultdict(dict)

    def read_storage(self, h: str, *, device=None) -> torch.UntypedStorage:
        if device is not None:
            device = torch.device(device)
        ws = (
            self.storage_cache[device].get(h)
            if self.storage_cache is not None
            else None
        )
        s: Optional[torch.UntypedStorage]
        if ws is not None:
            s = torch.UntypedStorage._new_with_weak_ptr(ws.cdata)
            if s is not None:
                return s
        s = torch.load(
            os.path.join(self.loc, "storages", h),
            weights_only=True,
            map_location=device,
        )._untyped_storage
        assert s is not None
        if self.storage_cache is not None:
            self.storage_cache[device][h] = StorageWeakRef(s)
        return s

    def read_tensor_metadata(self, name: str):
        fn = os.path.join(self.loc, "tensors", name)
        if not os.path.exists(fn):
            raise FileNotFoundError(fn)
        return torch.load(fn, weights_only=True)

    def read_tensor(self, name: str, *, device=None) -> torch.Tensor:
        dtype, h, storage_offset, size, stride, metadata = self.read_tensor_metadata(
            name
        )
        storage = self.read_storage(h, device=device)
        t = torch.tensor([], dtype=dtype, device=storage.device)
        t.set_(storage, storage_offset, size, stride)
        torch._utils.set_tensor_metadata(t, metadata)
        return t

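# End-to-end sketch of the writer/reader pair above (the directory and the
# tensor name here are hypothetical):
#
#     writer = ContentStoreWriter("/tmp/my_content_store")
#     t = torch.randn(4, 4)
#     writer.write_tensor("step0/activation", t)  # storage written once, keyed by hash
#
#     reader = ContentStoreReader("/tmp/my_content_store")
#     t2 = reader.read_tensor("step0/activation")
#     assert torch.equal(t, t2)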