# mypy: allow-untyped-defs
import contextlib
from typing import Optional

import torch
from torch.utils._content_store import ContentStoreReader


LOAD_TENSOR_READER: Optional[ContentStoreReader] = None


@contextlib.contextmanager
def load_tensor_reader(loc):
    global LOAD_TENSOR_READER
    assert LOAD_TENSOR_READER is None
    # load_tensor is an "op", and we will play merry hell on
    # Inductor's memory planning if we return a tensor that
    # aliases another tensor that we previously returned from
    # an operator.  So unlike standard ContentStoreReader use,
    # we disable the cache so that you always get fresh storages
    # (no aliasing for you!)
    LOAD_TENSOR_READER = ContentStoreReader(loc, cache=False)
    try:
        yield
    finally:
        LOAD_TENSOR_READER = None


def register_debug_prims():
    torch.library.define(
        "debugprims::load_tensor",
        "(str name, int[] size, int[] stride, *, ScalarType dtype, Device device) -> Tensor",
    )

    @torch.library.impl("debugprims::load_tensor", "BackendSelect")
    def load_tensor_factory(name, size, stride, dtype, device):
        if LOAD_TENSOR_READER is None:
            from torch._dynamo.testing import rand_strided

            return rand_strided(size, stride, dtype, device)
        else:
            from torch._dynamo.utils import clone_input

            # device argument here takes care of coercion
            r = LOAD_TENSOR_READER.read_tensor(name, device=device)
            assert list(r.size()) == size, f"{r.size()} != {size}"
            assert list(r.stride()) == stride, f"{r.stride()} != {stride}"
            assert r.device == device, f"{r.device} != {device}"

            # Unlike the other properties, we will do coercions for dtype
            # mismatch
            if r.dtype != dtype:
                r = clone_input(r, dtype=dtype)
            return r
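

# Usage sketch (illustrative only, not part of the module's tested surface).
# It assumes a content store was previously written to disk with
# torch.utils._content_store.ContentStoreWriter; the store location
# "/tmp/contents" and the tensor name "x0" below are hypothetical.
#
#   register_debug_prims()
#   with load_tensor_reader("/tmp/contents"):
#       t = torch.ops.debugprims.load_tensor(
#           "x0", [2, 3], [3, 1], dtype=torch.float32, device="cpu"
#       )
#   # Outside the context manager, the same call returns rand_strided data
#   # with the requested size/stride/dtype/device instead of stored contents.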