# mypy: allow-untyped-defs
import contextlib
from typing import Optional

import torch


# Set to True once init_torchbind_implementations() has completed successfully.
_TORCHBIND_IMPLS_INITIALIZED = False

# _TensorQueue script object created by init_torchbind_implementations();
# remains None until initialization has run.
_TENSOR_QUEUE_GLOBAL_TEST: Optional[torch.ScriptObject] = None


def init_torchbind_implementations():
    """Load the torchbind test library and register its fake implementations.

    Loads the compiled test library, registers fake/Meta kernels for its
    operators and fake classes for its script classes, then creates the
    module-level ``_TENSOR_QUEUE_GLOBAL_TEST`` queue. After one successful
    run, later calls return immediately. If any step raises (e.g. the
    ``unittest.SkipTest`` raised by ``load_torchbind_test_lib`` on macOS),
    the initialized flag is left unset.
    """
    global _TORCHBIND_IMPLS_INITIALIZED
    global _TENSOR_QUEUE_GLOBAL_TEST
    if _TORCHBIND_IMPLS_INITIALIZED:
        return

    # Order matters: the operators/classes referenced by the registration
    # helpers only exist after the library has been loaded.
    load_torchbind_test_lib()
    register_fake_operators()
    register_fake_classes()
    _TENSOR_QUEUE_GLOBAL_TEST = _empty_tensor_queue()
    _TORCHBIND_IMPLS_INITIALIZED = True


def _empty_tensor_queue() -> torch.ScriptObject:
    """Return a new ``_TorchScriptTesting._TensorQueue`` script object.

    The queue is constructed from a zero-element tensor (``fill_(-1)`` on it
    writes no elements). NOTE(review): the role of the constructor tensor is
    defined by the C++ class registration; presumably it is an
    initial/default value - confirm there.
    """
    return torch.classes._TorchScriptTesting._TensorQueue(
        torch.empty(0).fill_(-1)
    )


# The registrations below live inside functions because the operators and
# classes they refer to only exist once the test library has been loaded.
def register_fake_operators():
    """Register fake and Meta kernels for the ``_TorchScriptTesting`` operators.

    Each ``register_fake`` implementation forwards to the matching method on
    its script-object argument (``add_tensor``, ``pop``, ``push``, ``size``).
    The ``takes_foo``, ``takes_foo_list_return`` and ``takes_foo_tuple_return``
    overloads additionally get Python kernels for ``DispatchKey.Meta`` via
    ``py_impl``. This function has no guard of its own against repeated
    calls; ``init_torchbind_implementations`` only calls it once.
    """

    @torch.library.register_fake("_TorchScriptTesting::takes_foo_python_meta")
    def fake_takes_foo(foo, z):
        return foo.add_tensor(z)

    @torch.library.register_fake("_TorchScriptTesting::queue_pop")
    def fake_queue_pop(tq):
        return tq.pop()

    @torch.library.register_fake("_TorchScriptTesting::queue_push")
    def fake_queue_push(tq, x):
        return tq.push(x)

    @torch.library.register_fake("_TorchScriptTesting::queue_size")
    def fake_queue_size(tq):
        return tq.size()

    def meta_takes_foo_list_return(foo, x):
        # Chain add_tensor three times and return every intermediate result.
        a = foo.add_tensor(x)
        b = foo.add_tensor(a)
        c = foo.add_tensor(b)
        return [a, b, c]

    def meta_takes_foo_tuple_return(foo, x):
        # Chain add_tensor twice and return both intermediate results.
        a = foo.add_tensor(x)
        b = foo.add_tensor(a)
        return (a, b)

    torch.ops._TorchScriptTesting.takes_foo_list_return.default.py_impl(
        torch._C.DispatchKey.Meta
    )(meta_takes_foo_list_return)

    torch.ops._TorchScriptTesting.takes_foo_tuple_return.default.py_impl(
        torch._C.DispatchKey.Meta
    )(meta_takes_foo_tuple_return)

    torch.ops._TorchScriptTesting.takes_foo.default.py_impl(torch._C.DispatchKey.Meta)(
        # make signature match original cpp implementation to support kwargs
        lambda foo, x: foo.add_tensor(x)
    )


def register_fake_classes():
    """Register Python fake classes for ``_Foo`` and ``_ContainsTensor``."""

    @torch._library.register_fake_class("_TorchScriptTesting::_Foo")
    class FakeFoo:
        """Fake of ``_TorchScriptTesting::_Foo`` holding two integers."""

        def __init__(self, x: int, y: int):
            self.x = x
            self.y = y

        @classmethod
        def __obj_unflatten__(cls, flattend_foo):
            # The flattened form is an iterable of (name, value) pairs whose
            # names match __init__'s parameters.
            return cls(**dict(flattend_foo))

        def add_tensor(self, z):
            """Return ``(x + y) * z``."""
            return (self.x + self.y) * z

    @torch._library.register_fake_class("_TorchScriptTesting::_ContainsTensor")
    class FakeContainsTensor:
        """Fake of ``_TorchScriptTesting::_ContainsTensor`` wrapping one tensor."""

        def __init__(self, t: torch.Tensor):
            self.t = t

        @classmethod
        def __obj_unflatten__(cls, flattend_foo):
            # The flattened form is an iterable of (name, value) pairs whose
            # names match __init__'s parameters.
            return cls(**dict(flattend_foo))

        def get(self):
            """Return the wrapped tensor."""
            return self.t


def load_torchbind_test_lib():
    """Load the compiled library that registers the ``_TorchScriptTesting`` ops/classes.

    Raises:
        unittest.SkipTest: on macOS, where this load_library call is treated
            as non-portable.
    """
    import unittest

    from torch.testing._internal.common_utils import (  # type: ignore[attr-defined]
        find_library_location,
        IS_FBCODE,
        IS_MACOS,
        IS_SANDCASTLE,
        IS_WINDOWS,
    )

    if IS_SANDCASTLE or IS_FBCODE:
        # These environments load the library via what looks like a build
        # target label rather than a filesystem path.
        torch.ops.load_library("//caffe2/test/cpp/jit:test_custom_class_registrations")
    elif IS_MACOS:
        raise unittest.SkipTest("non-portable load_library call used in test")
    else:
        # Look up only the artifact name for the current platform.
        lib_name = "torchbind_test.dll" if IS_WINDOWS else "libtorchbind_test.so"
        lib_file_path = find_library_location(lib_name)
        torch.ops.load_library(str(lib_file_path))


@contextlib.contextmanager
def _register_py_impl_temporarily(op_overload, key, fn):
    """Install ``fn`` as ``op_overload``'s Python kernel for ``key`` within a ``with`` block.

    On exit (normal or via an exception) the kernel for ``key`` is removed and
    the overload's dispatch cache is cleared (presumably so no cached dispatch
    entry keeps referring to ``fn``). If the registration itself raises, the
    exception propagates and nothing is removed, so a kernel that was already
    registered for ``key`` is left intact.
    """
    # Register outside the try: if py_impl raises, the finally below must not
    # run, or it would delete a kernel this helper never installed.
    op_overload.py_impl(key)(fn)
    try:
        yield
    finally:
        # pop() rather than del: if the kernel was already removed inside the
        # with-block, a KeyError here would mask the original exception.
        op_overload.py_kernels.pop(key, None)
        op_overload._dispatch_cache.clear()