xref: /aosp_15_r20/external/pytorch/torch/testing/_internal/torchbind_impls.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import contextlib
3from typing import Optional
4
5import torch
6
7
# Set to True after the torchbind test library has been loaded and the fake
# ops/classes registered; guards init_torchbind_implementations() against
# running twice.
_TORCHBIND_IMPLS_INITIALIZED = False

# Shared _TensorQueue script object for tests; None until
# init_torchbind_implementations() populates it.
_TENSOR_QUEUE_GLOBAL_TEST: Optional[torch.ScriptObject] = None
11
12
def init_torchbind_implementations():
    """Load the torchbind test library and register its fake ops/classes.

    Idempotent: subsequent calls after the first are no-ops. Also populates
    the module-level ``_TENSOR_QUEUE_GLOBAL_TEST`` singleton.
    """
    global _TORCHBIND_IMPLS_INITIALIZED
    global _TENSOR_QUEUE_GLOBAL_TEST

    if not _TORCHBIND_IMPLS_INITIALIZED:
        load_torchbind_test_lib()
        register_fake_operators()
        register_fake_classes()
        _TENSOR_QUEUE_GLOBAL_TEST = _empty_tensor_queue()
        _TORCHBIND_IMPLS_INITIALIZED = True
24
25
def _empty_tensor_queue() -> torch.ScriptObject:
    """Return a fresh ``_TensorQueue`` initialized with an empty tensor filled with -1."""
    init_tensor = torch.empty(0).fill_(-1)
    return torch.classes._TorchScriptTesting._TensorQueue(init_tensor)
32
33
# put these under a function because the corresponding library might not be loaded yet.
def register_fake_operators():
    """Register fake (meta) implementations for the _TorchScriptTesting custom ops."""

    def fake_takes_foo(foo, z):
        return foo.add_tensor(z)

    def fake_queue_pop(tq):
        return tq.pop()

    def fake_queue_push(tq, x):
        return tq.push(x)

    def fake_queue_size(tq):
        return tq.size()

    # Explicit registration calls instead of decorator syntax; the effect is
    # identical.
    torch.library.register_fake("_TorchScriptTesting::takes_foo_python_meta")(
        fake_takes_foo
    )
    torch.library.register_fake("_TorchScriptTesting::queue_pop")(fake_queue_pop)
    torch.library.register_fake("_TorchScriptTesting::queue_push")(fake_queue_push)
    torch.library.register_fake("_TorchScriptTesting::queue_size")(fake_queue_size)

    def meta_takes_foo_list_return(foo, x):
        a = foo.add_tensor(x)
        b = foo.add_tensor(a)
        c = foo.add_tensor(b)
        return [a, b, c]

    def meta_takes_foo_tuple_return(foo, x):
        a = foo.add_tensor(x)
        b = foo.add_tensor(a)
        return (a, b)

    meta_key = torch._C.DispatchKey.Meta

    torch.ops._TorchScriptTesting.takes_foo_list_return.default.py_impl(meta_key)(
        meta_takes_foo_list_return
    )
    torch.ops._TorchScriptTesting.takes_foo_tuple_return.default.py_impl(meta_key)(
        meta_takes_foo_tuple_return
    )
    torch.ops._TorchScriptTesting.takes_foo.default.py_impl(meta_key)(
        # make signature match original cpp implementation to support kwargs
        lambda foo, x: foo.add_tensor(x)
    )
75
76
def register_fake_classes():
    """Register Python fake counterparts for the _TorchScriptTesting torchbind classes."""

    @torch._library.register_fake_class("_TorchScriptTesting::_Foo")
    class FakeFoo:
        def __init__(self, x: int, y: int):
            self.x = x
            self.y = y

        @classmethod
        def __obj_unflatten__(cls, flattened):
            # `flattened` is an iterable of (attribute_name, value) pairs.
            return cls(**dict(flattened))

        def add_tensor(self, z):
            scale = self.x + self.y
            return scale * z

    @torch._library.register_fake_class("_TorchScriptTesting::_ContainsTensor")
    class FakeContainsTensor:
        def __init__(self, t: torch.Tensor):
            self.t = t

        @classmethod
        def __obj_unflatten__(cls, flattened):
            # Same (name, value)-pair convention as FakeFoo above.
            return cls(**dict(flattened))

        def get(self):
            return self.t
102
103
def load_torchbind_test_lib():
    """Load the C++ library that registers the _TorchScriptTesting classes/ops.

    Raises:
        unittest.SkipTest: on macOS, where the load_library call used here is
            not portable.
    """
    import unittest

    from torch.testing._internal.common_utils import (  # type: ignore[attr-defined]
        find_library_location,
        IS_FBCODE,
        IS_MACOS,
        IS_SANDCASTLE,
        IS_WINDOWS,
    )

    if IS_SANDCASTLE or IS_FBCODE:
        # Internal builds address the library by its buck target.
        torch.ops.load_library("//caffe2/test/cpp/jit:test_custom_class_registrations")
    elif IS_MACOS:
        raise unittest.SkipTest("non-portable load_library call used in test")
    else:
        # Resolve only the platform-appropriate library name (the previous code
        # always looked up the .so first and then discarded that result on
        # Windows).
        lib_name = "torchbind_test.dll" if IS_WINDOWS else "libtorchbind_test.so"
        torch.ops.load_library(str(find_library_location(lib_name)))
124
125
126@contextlib.contextmanager
127def _register_py_impl_temporarily(op_overload, key, fn):
128    try:
129        op_overload.py_impl(key)(fn)
130        yield
131    finally:
132        del op_overload.py_kernels[key]
133        op_overload._dispatch_cache.clear()
134