# mypy: allow-untyped-defs
# This module provides a FAST (on GPU) content addressable store for storages
# (and tensors on top of them) with VERY WEAK portability guarantees (e.g.,
# don't expect CPU and CUDA to hash to the same value, don't expect it to be
# portable across devices) that is NOT cryptographically secure.  In return,
# we are able to hash 40G of tensor data on GPU in less than a second,
# compared to running SHA-1 on CPU, which would take a minute or so.  The
# primary use case is efficiently snapshotting intermediate tensor data for
# offline debugging, but it's been put in this module in case you think of
# another use case for it.  The hash function could be replaced with a
# straight reimplementation of SHA-1, which would give us much stronger
# portability guarantees.
#
# WARNING: THERE IS NO BC/FC GUARANTEE FOR THIS FORMAT!  If you need to
# format-shift the result, consider packing it into a single torch.save
# object with traditional view sharing.
#
# Because of the weak portability guarantees, you can only write to the
# content store from a single process; we don't provide any capability
# of "reopening" a content store to add more things to it.  But we don't
# assume that you can keep all of the tensors you want to add to the store
# in memory at once, because you probably can't!  Nor do we assume that
# you know a priori whether two storages can be deduplicated.
#
# Note: only storages are content-addressed; tensors are name-addressed.
#
# Note: our padding strategy means that [1, 0] and [1] int16 tensors would
# map to the same (padded) storage.  We think this will be immaterial for
# most users.
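#
# A minimal usage sketch (illustrative; the snapshot path and tensor name
# below are hypothetical):
#
#   writer = ContentStoreWriter("/tmp/debug_snapshot")
#   writer.write_tensor("step0/activation", some_tensor)
#
#   reader = ContentStoreReader("/tmp/debug_snapshot")
#   restored = reader.read_tensor("step0/activation")
#
# Storages are deduplicated by content hash; each tensor name stores only
# metadata (dtype, sizes, strides, storage offset) plus the hash of its
# underlying storage.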

import ctypes
import functools
import hashlib
import os.path
import struct
from collections import defaultdict
from typing import Dict, Optional, Set

import torch
import torch._prims as prims
import torch._utils
import torch.nn.functional as F
from torch._C import default_generator
from torch.multiprocessing.reductions import StorageWeakRef


def lazy_compile(**compile_kwargs):
    """Lazily wrap a function with torch.compile on the first call

    This avoids eagerly importing dynamo.
    """

    def decorate_fn(fn):
        @functools.wraps(fn)
        def compile_hook(*args, **kwargs):
            compiled_fn = torch.compile(fn, **compile_kwargs)
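            # Cache the compiled function by replacing the module-level name,
            # so subsequent calls go straight to the compiled version instead
            # of re-entering this hook.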
            globals()[fn.__name__] = functools.wraps(fn)(compiled_fn)
            return compiled_fn(*args, **kwargs)

        return compile_hook

    return decorate_fn


# Use of torch.compile is mandatory for (1) good memory usage
# and (2) the xor_sum implementation.  This is our first instance of
# using PT2 to implement a kernel in PyTorch; if we get AOT capabilities
# it would be good to apply it here.
@lazy_compile(dynamic=True)
def hash_storage_kernel(x):
    # The randint calls are carefully written to hit things we
    # have lowerings for in inductor.  The lack of an unsigned 32-bit
    # integer dtype is a pain.
    a = torch.randint(
        -(2**31), 2**31, x.shape, device=x.device, dtype=torch.int32
    ).abs()
    a = ((a % (2**31 - 1)) + 1).long()
    b = (
        torch.randint(-(2**31), 2**31, x.shape, device=x.device, dtype=torch.int32)
        .abs()
        .long()
    )
    # This is a standard shift-multiply universal hash family
    # plus xor sum hash, using Philox to generate random numbers.
    # Our Philox RNG is not deterministic across devices, so
    # don't use this for stable hashing.
    #
    # This assumes fixed length, so you're also obligated to bucket
    # by the length of the tensor as well.
    return prims.xor_sum((a * x + b).int(), [0])
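
# For reference, the kernel above maps each 32-bit word x_i of the input to
# (a_i * x_i + b_i) truncated to int32, then XOR-reduces the words.  A rough
# eager-mode sketch of the same reduction (for exposition only, not used
# anywhere in this module):
#
#   def _reference_xor_hash(x_int32, a, b):
#       acc = 0
#       for w in (a * x_int32 + b).int().tolist():
#           acc ^= w
#       return acc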


# Returns a hex digest of the data in the storage.  Guaranteed to be
# SHA-1 if stable_hash=True; otherwise it will be consistent for a single
# process run but not necessarily across processes.
def hash_storage(storage: torch.UntypedStorage, *, stable_hash: bool = False) -> str:
    import torch._dynamo
    from torch._dynamo.utils import is_compile_supported

    device_type = storage.device.type
    if stable_hash or not is_compile_supported(device_type):
        cpu_storage = storage.cpu()
        # TODO: make storage support buffer protocol so this isn't
        # necessary
        buf = (ctypes.c_byte * cpu_storage.nbytes()).from_address(
            cpu_storage.data_ptr()
        )
        sha1 = hashlib.sha1()
        sha1.update(buf)
        return sha1.hexdigest()

    # TODO: factor this into a random utility
    if device_type == "cpu":
        generator = default_generator
    elif device_type == "cuda":
        import torch.cuda

        generator = torch.cuda.default_generators[storage.device.index]
    else:
        raise AssertionError(f"unhandled device type {device_type}")
    state = generator.get_state()
    try:
        generator.manual_seed(0)
        x = torch.empty(0, dtype=torch.uint8, device=storage.device).set_(storage)  # type: ignore[call-overload]
        # The dtype-casting view cannot be compiled, and so the
        # padding/reshaping also needs to be done externally even
        # though it could be profitably fused
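        # -numel % 4 is the number of padding bytes needed to round the byte
        # count up to a multiple of 4 so it can be viewed as int32 words.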
        pad = -x.numel() % 4
        if pad > 0:
            x = F.pad(x, (0, pad), "constant", 0)
        x = x.view(torch.int32)
        # We run the 32-bit hash five times with differing parameters to
        # reduce chance of collision
        ITER = 5
        cs = [hash_storage_kernel(x).item() for _ in range(ITER)]
        return struct.pack(">" + "i" * ITER, *cs).hex()
    finally:
        generator.set_state(state)
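

# Example (illustrative; `t` stands in for any tensor): the default fast path
# produces a digest that is only comparable within the current process, while
# stable_hash=True always uses SHA-1 over the CPU bytes:
#
#   fast = hash_storage(t.untyped_storage())
#   stable = hash_storage(t.untyped_storage(), stable_hash=True)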


class ContentStoreWriter:
    # Structure:
    #   storages/
    #     00/
    #       0000..00
    #   tensors/
    #     name
    def __init__(self, loc: str, stable_hash: bool = False) -> None:
        self.loc: str = loc
        self.seen_storage_hashes: Set[str] = set()
        self.stable_hash = stable_hash

    # TODO: offer some sort of non-blocking API to speed things up
    def write_storage(self, storage: torch.UntypedStorage) -> str:
        h = hash_storage(storage, stable_hash=self.stable_hash)
        if h in self.seen_storage_hashes:
            return h
        # TODO: consider not using torch.save for this; we don't actually
        # need any metadata for the storage
        subfolder = os.path.join(self.loc, "storages")
        os.makedirs(subfolder, exist_ok=True)
        target = os.path.join(subfolder, h)
        if os.path.exists(target):
            return h
        torch.save(storage, target)
        self.seen_storage_hashes.add(h)
        return h

    def compute_tensor_metadata(self, t: torch.Tensor, h=None):
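        # Returns the tuple that is saved for each tensor name; for a
        # contiguous 2x3 float32 tensor it would look roughly like
        # (hash and backend metadata elided):
        #   (torch.float32, "<hash>", 0, (2, 3), (3, 1), ...)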
        if h is None:
            h = hash_storage(t.untyped_storage(), stable_hash=self.stable_hash)
        return (
            t.dtype,
            h,
            t.storage_offset(),
            tuple(t.shape),
            t.stride(),
            torch._utils.get_tensor_metadata(t),
        )

    def write_tensor(self, name: str, t: torch.Tensor) -> None:
        storage = t.untyped_storage()
        h = self.write_storage(storage)
        # TODO: Support more advanced snapshotting of requires_grad/grad/etc
        d, f = os.path.split(name)
        payload = self.compute_tensor_metadata(t, h=h)
        subfolder = os.path.join(self.loc, "tensors", d)
        os.makedirs(subfolder, exist_ok=True)
        torch.save(payload, os.path.join(subfolder, f))


class ContentStoreReader:
    def __init__(self, loc: str, *, cache=True) -> None:
        self.loc = loc
        self.storage_cache: Optional[
            Dict[Optional[torch.device], Dict[str, StorageWeakRef]]
        ] = None
        if cache:
            self.storage_cache = defaultdict(dict)

    def read_storage(self, h: str, *, device=None) -> torch.UntypedStorage:
        if device is not None:
            device = torch.device(device)
        ws = (
            self.storage_cache[device].get(h)
            if self.storage_cache is not None
            else None
        )
        s: Optional[torch.UntypedStorage]
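        # Try to resurrect the cached storage from its weak reference; the
        # referent may already have been freed, in which case we fall through
        # to loading it from disk again.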
        if ws is not None:
            s = torch.UntypedStorage._new_with_weak_ptr(ws.cdata)
            if s is not None:
                return s
        s = torch.load(
            os.path.join(self.loc, "storages", h),
            weights_only=True,
            map_location=device,
        )._untyped_storage
        assert s is not None
        if self.storage_cache is not None:
            self.storage_cache[device][h] = StorageWeakRef(s)
        return s

    def read_tensor_metadata(self, name: str):
        fn = os.path.join(self.loc, "tensors", name)
        if not os.path.exists(fn):
            raise FileNotFoundError(fn)
        return torch.load(fn, weights_only=True)

    def read_tensor(self, name: str, *, device=None) -> torch.Tensor:
        dtype, h, storage_offset, size, stride, metadata = self.read_tensor_metadata(
            name
        )
        storage = self.read_storage(h, device=device)
        t = torch.tensor([], dtype=dtype, device=storage.device)
        t.set_(storage, storage_offset, size, stride)
        torch._utils.set_tensor_metadata(t, metadata)
        return t
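

# Illustrative reader usage (path and tensor name hypothetical).  With the
# default cache=True, repeated reads of a storage whose hash is already live
# on the requested device reuse it via StorageWeakRef rather than reloading
# it from disk:
#
#   reader = ContentStoreReader("/tmp/debug_snapshot")
#   t = reader.read_tensor("step0/activation", device="cpu")
#   t2 = reader.read_tensor("step0/activation", device="cpu")  # storage cache hit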