xref: /aosp_15_r20/external/pytorch/torch/_prims/debug_prims.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import contextlib
3from typing import Optional
4
5import torch
6from torch.utils._content_store import ContentStoreReader
7
8
# Reader installed by ``load_tensor_reader`` for the duration of its context,
# and ``None`` otherwise. While it is ``None``, ``debugprims::load_tensor``
# returns random data of the requested layout (via ``rand_strided``) instead
# of reading a saved tensor from the content store.
LOAD_TENSOR_READER: Optional[ContentStoreReader] = None
10
11
12@contextlib.contextmanager
13def load_tensor_reader(loc):
14    global LOAD_TENSOR_READER
15    assert LOAD_TENSOR_READER is None
16    # load_tensor is an "op", and we will play merry hell on
17    # Inductor's memory planning if we return a tensor that
18    # aliases another tensor that we previously returned from
19    # an operator.  So unlike standard ContentStoreReader use,
20    # we disable the cache so that you always get fresh storages
21    # (no aliasing for you!)
22    LOAD_TENSOR_READER = ContentStoreReader(loc, cache=False)
23    try:
24        yield
25    finally:
26        LOAD_TENSOR_READER = None
27
28
def register_debug_prims():
    """Define ``debugprims::load_tensor`` and register its BackendSelect impl.

    The op produces a tensor with the given ``size``/``stride``/``dtype`` on
    ``device``.

    * If a reader is installed via ``load_tensor_reader``, the tensor called
      *name* is read from the content store. Its size, stride and device are
      checked against the request, and its dtype is converted if it differs.
    * Otherwise random data of the requested layout is returned from
      ``rand_strided``.
    """
    torch.library.define(
        "debugprims::load_tensor",
        "(str name, int[] size, int[] stride, *, ScalarType dtype, Device device) -> Tensor",
    )

    def _device_matches(actual, requested):
        # torch.device("cuda") != torch.device("cuda:0"), but a CUDA tensor's
        # .device always carries a concrete index. An index-less request is
        # therefore satisfied by any device of the same type; an indexed
        # request must match exactly.
        requested = torch.device(requested)
        if requested.index is None:
            return actual.type == requested.type
        return actual == requested

    @torch.library.impl("debugprims::load_tensor", "BackendSelect")
    def load_tensor_factory(name, size, stride, dtype, device):
        if LOAD_TENSOR_READER is None:
            # No reader installed: fabricate data with the requested layout.
            from torch._dynamo.testing import rand_strided

            return rand_strided(size, stride, dtype, device)

        from torch._dynamo.utils import clone_input

        # device argument here takes care of coercion
        r = LOAD_TENSOR_READER.read_tensor(name, device=device)
        # list(...) on both sides so tuple-like size/stride don't spuriously fail.
        assert list(r.size()) == list(size), f"{r.size()} != {size}"
        assert list(r.stride()) == list(stride), f"{r.stride()} != {stride}"
        assert _device_matches(r.device, device), f"{r.device} != {device}"

        # Unlike the other properties, we will do coercions for dtype
        # mismatch
        if r.dtype != dtype:
            r = clone_input(r, dtype=dtype)
        return r
55