# mypy: allow-untyped-defs
import types

import torch._C


class _ClassNamespace(types.ModuleType):
    def __init__(self, name):
        super().__init__("torch.classes" + name)
        self.name = name

    def __getattr__(self, attr):
        proxy = torch._C._get_custom_class_python_wrapper(self.name, attr)
        if proxy is None:
            raise RuntimeError(f"Class {self.name}.{attr} not registered!")
        return proxy


class _Classes(types.ModuleType):
    __file__ = "_classes.py"

    def __init__(self) -> None:
        super().__init__("torch.classes")

    def __getattr__(self, name):
        namespace = _ClassNamespace(name)
        setattr(self, name, namespace)
        return namespace

    @property
    def loaded_libraries(self):
        return torch.ops.loaded_libraries

    def load_library(self, path):
        """
        Loads a shared library from the given path into the current process.

        The library being loaded may run global initialization code to register
        custom classes with the PyTorch JIT runtime. This allows dynamically
        loading custom classes. For this, you should compile your class
        and the static registration code into a shared library object, and then
        call ``torch.classes.load_library('path/to/libcustom.so')`` to load the
        shared object.

        After the library is loaded, it is added to the
        ``torch.classes.loaded_libraries`` attribute, a set that may be inspected
        for the paths of all libraries loaded using this function.

        Args:
            path (str): A path to a shared library to load.
        """
        torch.ops.load_library(path)


# The classes "namespace"
classes = _Classes()