xref: /aosp_15_r20/external/pytorch/torch/_classes.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerimport types
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerimport torch._C
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Workerclass _ClassNamespace(types.ModuleType):
8*da0073e9SAndroid Build Coastguard Worker    def __init__(self, name):
9*da0073e9SAndroid Build Coastguard Worker        super().__init__("torch.classes" + name)
10*da0073e9SAndroid Build Coastguard Worker        self.name = name
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Worker    def __getattr__(self, attr):
13*da0073e9SAndroid Build Coastguard Worker        proxy = torch._C._get_custom_class_python_wrapper(self.name, attr)
14*da0073e9SAndroid Build Coastguard Worker        if proxy is None:
15*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(f"Class {self.name}.{attr} not registered!")
16*da0073e9SAndroid Build Coastguard Worker        return proxy
17*da0073e9SAndroid Build Coastguard Worker
18*da0073e9SAndroid Build Coastguard Worker
19*da0073e9SAndroid Build Coastguard Workerclass _Classes(types.ModuleType):
20*da0073e9SAndroid Build Coastguard Worker    __file__ = "_classes.py"
21*da0073e9SAndroid Build Coastguard Worker
22*da0073e9SAndroid Build Coastguard Worker    def __init__(self) -> None:
23*da0073e9SAndroid Build Coastguard Worker        super().__init__("torch.classes")
24*da0073e9SAndroid Build Coastguard Worker
25*da0073e9SAndroid Build Coastguard Worker    def __getattr__(self, name):
26*da0073e9SAndroid Build Coastguard Worker        namespace = _ClassNamespace(name)
27*da0073e9SAndroid Build Coastguard Worker        setattr(self, name, namespace)
28*da0073e9SAndroid Build Coastguard Worker        return namespace
29*da0073e9SAndroid Build Coastguard Worker
30*da0073e9SAndroid Build Coastguard Worker    @property
31*da0073e9SAndroid Build Coastguard Worker    def loaded_libraries(self):
32*da0073e9SAndroid Build Coastguard Worker        return torch.ops.loaded_libraries
33*da0073e9SAndroid Build Coastguard Worker
34*da0073e9SAndroid Build Coastguard Worker    def load_library(self, path):
35*da0073e9SAndroid Build Coastguard Worker        """
36*da0073e9SAndroid Build Coastguard Worker        Loads a shared library from the given path into the current process.
37*da0073e9SAndroid Build Coastguard Worker
38*da0073e9SAndroid Build Coastguard Worker        The library being loaded may run global initialization code to register
39*da0073e9SAndroid Build Coastguard Worker        custom classes with the PyTorch JIT runtime. This allows dynamically
40*da0073e9SAndroid Build Coastguard Worker        loading custom classes. For this, you should compile your class
41*da0073e9SAndroid Build Coastguard Worker        and the static registration code into a shared library object, and then
42*da0073e9SAndroid Build Coastguard Worker        call ``torch.classes.load_library('path/to/libcustom.so')`` to load the
43*da0073e9SAndroid Build Coastguard Worker        shared object.
44*da0073e9SAndroid Build Coastguard Worker
45*da0073e9SAndroid Build Coastguard Worker        After the library is loaded, it is added to the
46*da0073e9SAndroid Build Coastguard Worker        ``torch.classes.loaded_libraries`` attribute, a set that may be inspected
47*da0073e9SAndroid Build Coastguard Worker        for the paths of all libraries loaded using this function.
48*da0073e9SAndroid Build Coastguard Worker
49*da0073e9SAndroid Build Coastguard Worker        Args:
50*da0073e9SAndroid Build Coastguard Worker            path (str): A path to a shared library to load.
51*da0073e9SAndroid Build Coastguard Worker        """
52*da0073e9SAndroid Build Coastguard Worker        torch.ops.load_library(path)
53*da0073e9SAndroid Build Coastguard Worker
54*da0073e9SAndroid Build Coastguard Worker
# Single shared instance of the ``torch.classes`` module object; attribute
# access on it yields the per-namespace views of registered custom classes.
classes = _Classes()
57