1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: custom-operators"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport collections 4*da0073e9SAndroid Build Coastguard Workerimport itertools 5*da0073e9SAndroid Build Coastguard Workerimport os 6*da0073e9SAndroid Build Coastguard Workerimport re 7*da0073e9SAndroid Build Coastguard Workerimport subprocess 8*da0073e9SAndroid Build Coastguard Workerimport sys 9*da0073e9SAndroid Build Coastguard Workerimport typing 10*da0073e9SAndroid Build Coastguard Workerimport unittest 11*da0073e9SAndroid Build Coastguard Workerfrom typing import * # noqa: F403 12*da0073e9SAndroid Build Coastguard Worker 13*da0073e9SAndroid Build Coastguard Workerimport numpy as np 14*da0073e9SAndroid Build Coastguard Worker 15*da0073e9SAndroid Build Coastguard Workerimport torch._custom_ops as custom_ops 16*da0073e9SAndroid Build Coastguard Workerimport torch.testing._internal.optests as optests 17*da0073e9SAndroid Build Coastguard Workerimport torch.utils._pytree as pytree 18*da0073e9SAndroid Build Coastguard Workerimport torch.utils.cpp_extension 19*da0073e9SAndroid Build Coastguard Workerfrom functorch import make_fx 20*da0073e9SAndroid Build Coastguard Workerfrom torch import Tensor 21*da0073e9SAndroid Build Coastguard Workerfrom torch._custom_op.impl import CustomOp, infer_schema 22*da0073e9SAndroid Build Coastguard Workerfrom torch._library.infer_schema import tuple_to_list 23*da0073e9SAndroid Build Coastguard Workerfrom torch._utils_internal import get_file_path_2 24*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal import custom_op_db 25*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_cuda import TEST_CUDA 26*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import ( 27*da0073e9SAndroid Build Coastguard Worker instantiate_device_type_tests, 28*da0073e9SAndroid Build Coastguard Worker OpDTypes, 
29*da0073e9SAndroid Build Coastguard Worker ops, 30*da0073e9SAndroid Build Coastguard Worker) 31*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import ( 32*da0073e9SAndroid Build Coastguard Worker instantiate_parametrized_tests, 33*da0073e9SAndroid Build Coastguard Worker IS_WINDOWS, 34*da0073e9SAndroid Build Coastguard Worker parametrize, 35*da0073e9SAndroid Build Coastguard Worker run_tests, 36*da0073e9SAndroid Build Coastguard Worker skipIfTorchDynamo, 37*da0073e9SAndroid Build Coastguard Worker subtest, 38*da0073e9SAndroid Build Coastguard Worker TestCase, 39*da0073e9SAndroid Build Coastguard Worker) 40*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.custom_op_db import numpy_nonzero 41*da0073e9SAndroid Build Coastguard Worker 42*da0073e9SAndroid Build Coastguard Worker 43*da0073e9SAndroid Build Coastguard Worker# Shadowed by `torch.testing._internal.common_utils.custom_op` 44*da0073e9SAndroid Build Coastguard Workerfrom torch._custom_op.impl import custom_op # usort: skip 45*da0073e9SAndroid Build Coastguard Worker 46*da0073e9SAndroid Build Coastguard Worker 47*da0073e9SAndroid Build Coastguard Workerdef requires_compile(fun): 48*da0073e9SAndroid Build Coastguard Worker fun = unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")(fun) 49*da0073e9SAndroid Build Coastguard Worker return fun 50*da0073e9SAndroid Build Coastguard Worker 51*da0073e9SAndroid Build Coastguard Worker 52*da0073e9SAndroid Build Coastguard Workerclass CustomOpTestCaseBase(TestCase): 53*da0073e9SAndroid Build Coastguard Worker test_ns = "_test_custom_op" 54*da0073e9SAndroid Build Coastguard Worker 55*da0073e9SAndroid Build Coastguard Worker def setUp(self): 56*da0073e9SAndroid Build Coastguard Worker super().setUp() 57*da0073e9SAndroid Build Coastguard Worker self.libraries = [] 58*da0073e9SAndroid Build Coastguard Worker 59*da0073e9SAndroid Build Coastguard Worker def tearDown(self): 60*da0073e9SAndroid Build 
Coastguard Worker super().tearDown() 61*da0073e9SAndroid Build Coastguard Worker import torch._custom_op 62*da0073e9SAndroid Build Coastguard Worker 63*da0073e9SAndroid Build Coastguard Worker keys = list(torch._custom_op.impl.global_registry.keys()) 64*da0073e9SAndroid Build Coastguard Worker for key in keys: 65*da0073e9SAndroid Build Coastguard Worker if not key.startswith(f"{self.test_ns}::"): 66*da0073e9SAndroid Build Coastguard Worker continue 67*da0073e9SAndroid Build Coastguard Worker torch._custom_op.impl.global_registry[key]._destroy() 68*da0073e9SAndroid Build Coastguard Worker if hasattr(torch.ops, self.test_ns): 69*da0073e9SAndroid Build Coastguard Worker delattr(torch.ops, self.test_ns) 70*da0073e9SAndroid Build Coastguard Worker for lib in self.libraries: 71*da0073e9SAndroid Build Coastguard Worker lib._destroy() 72*da0073e9SAndroid Build Coastguard Worker del self.libraries 73*da0073e9SAndroid Build Coastguard Worker 74*da0073e9SAndroid Build Coastguard Worker def ns(self): 75*da0073e9SAndroid Build Coastguard Worker return getattr(torch.ops, self.test_ns) 76*da0073e9SAndroid Build Coastguard Worker 77*da0073e9SAndroid Build Coastguard Worker def lib(self): 78*da0073e9SAndroid Build Coastguard Worker result = torch.library.Library(self.test_ns, "FRAGMENT") # noqa: TOR901 79*da0073e9SAndroid Build Coastguard Worker self.libraries.append(result) 80*da0073e9SAndroid Build Coastguard Worker return result 81*da0073e9SAndroid Build Coastguard Worker 82*da0073e9SAndroid Build Coastguard Worker def get_op(self, qualname): 83*da0073e9SAndroid Build Coastguard Worker return torch._custom_op.impl.get_op(qualname) 84*da0073e9SAndroid Build Coastguard Worker 85*da0073e9SAndroid Build Coastguard Worker 86*da0073e9SAndroid Build Coastguard Worker@requires_compile 87*da0073e9SAndroid Build Coastguard Workerclass TestCustomOpTesting(CustomOpTestCaseBase): 88*da0073e9SAndroid Build Coastguard Worker @parametrize("check_gradients", (False, "auto")) 89*da0073e9SAndroid 
Build Coastguard Worker @parametrize("dynamic", (True, False)) 90*da0073e9SAndroid Build Coastguard Worker def test_aot_autograd_check_degenerate_cases( 91*da0073e9SAndroid Build Coastguard Worker self, device, dynamic, check_gradients 92*da0073e9SAndroid Build Coastguard Worker ): 93*da0073e9SAndroid Build Coastguard Worker def simple(x): 94*da0073e9SAndroid Build Coastguard Worker return x.clone() 95*da0073e9SAndroid Build Coastguard Worker 96*da0073e9SAndroid Build Coastguard Worker # Should not raise 97*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, device=device) 98*da0073e9SAndroid Build Coastguard Worker optests.aot_autograd_check( 99*da0073e9SAndroid Build Coastguard Worker simple, (x,), {}, dynamic=dynamic, check_gradients=check_gradients 100*da0073e9SAndroid Build Coastguard Worker ) 101*da0073e9SAndroid Build Coastguard Worker 102*da0073e9SAndroid Build Coastguard Worker def outputs_dont_require_grad(x): 103*da0073e9SAndroid Build Coastguard Worker return x.detach() 104*da0073e9SAndroid Build Coastguard Worker 105*da0073e9SAndroid Build Coastguard Worker # Should not raise 106*da0073e9SAndroid Build Coastguard Worker y = torch.randn(3, device=device, requires_grad=True) 107*da0073e9SAndroid Build Coastguard Worker optests.aot_autograd_check( 108*da0073e9SAndroid Build Coastguard Worker simple, (y,), {}, dynamic=dynamic, check_gradients=check_gradients 109*da0073e9SAndroid Build Coastguard Worker ) 110*da0073e9SAndroid Build Coastguard Worker 111*da0073e9SAndroid Build Coastguard Worker def no_outputs(x): 112*da0073e9SAndroid Build Coastguard Worker return x.detach() 113*da0073e9SAndroid Build Coastguard Worker 114*da0073e9SAndroid Build Coastguard Worker # Should not raise 115*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, device=device, requires_grad=True) 116*da0073e9SAndroid Build Coastguard Worker y = torch.randn(3, device=device, requires_grad=False) 117*da0073e9SAndroid Build Coastguard Worker optests.aot_autograd_check( 
118*da0073e9SAndroid Build Coastguard Worker no_outputs, (x,), {}, dynamic=dynamic, check_gradients=check_gradients 119*da0073e9SAndroid Build Coastguard Worker ) 120*da0073e9SAndroid Build Coastguard Worker optests.aot_autograd_check( 121*da0073e9SAndroid Build Coastguard Worker no_outputs, (y,), {}, dynamic=dynamic, check_gradients=check_gradients 122*da0073e9SAndroid Build Coastguard Worker ) 123*da0073e9SAndroid Build Coastguard Worker 124*da0073e9SAndroid Build Coastguard Worker def test_incorrect_schema_mutation(self, device): 125*da0073e9SAndroid Build Coastguard Worker lib = self.lib() 126*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor x) -> Tensor") 127*da0073e9SAndroid Build Coastguard Worker op = self.ns().foo.default 128*da0073e9SAndroid Build Coastguard Worker 129*da0073e9SAndroid Build Coastguard Worker class Foo(torch.autograd.Function): 130*da0073e9SAndroid Build Coastguard Worker @staticmethod 131*da0073e9SAndroid Build Coastguard Worker def forward(ctx, x): 132*da0073e9SAndroid Build Coastguard Worker guard = torch._C._AutoDispatchBelowAutograd() 133*da0073e9SAndroid Build Coastguard Worker try: 134*da0073e9SAndroid Build Coastguard Worker return op(x) 135*da0073e9SAndroid Build Coastguard Worker finally: 136*da0073e9SAndroid Build Coastguard Worker del guard 137*da0073e9SAndroid Build Coastguard Worker 138*da0073e9SAndroid Build Coastguard Worker @staticmethod 139*da0073e9SAndroid Build Coastguard Worker def backward(ctx, gx): 140*da0073e9SAndroid Build Coastguard Worker return gx 141*da0073e9SAndroid Build Coastguard Worker 142*da0073e9SAndroid Build Coastguard Worker def foo_impl(x): 143*da0073e9SAndroid Build Coastguard Worker x.sin_() 144*da0073e9SAndroid Build Coastguard Worker return x.clone() 145*da0073e9SAndroid Build Coastguard Worker 146*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", Foo.apply, "Autograd") 147*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", foo_impl, "CPU") 148*da0073e9SAndroid Build 
Coastguard Worker lib.impl("foo", foo_impl, "CUDA") 149*da0073e9SAndroid Build Coastguard Worker 150*da0073e9SAndroid Build Coastguard Worker x = torch.tensor(3.14159 / 3, requires_grad=True, device=device) 151*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 152*da0073e9SAndroid Build Coastguard Worker optests.OpCheckError, "Argument x is not defined as mutable but was mutated" 153*da0073e9SAndroid Build Coastguard Worker ): 154*da0073e9SAndroid Build Coastguard Worker torch.library.opcheck(op, (x,), {}) 155*da0073e9SAndroid Build Coastguard Worker 156*da0073e9SAndroid Build Coastguard Worker def test_incorrect_schema_view(self, device): 157*da0073e9SAndroid Build Coastguard Worker lib = self.lib() 158*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor x) -> Tensor") 159*da0073e9SAndroid Build Coastguard Worker op = self.ns().foo.default 160*da0073e9SAndroid Build Coastguard Worker 161*da0073e9SAndroid Build Coastguard Worker class Foo(torch.autograd.Function): 162*da0073e9SAndroid Build Coastguard Worker @staticmethod 163*da0073e9SAndroid Build Coastguard Worker def forward(ctx, x): 164*da0073e9SAndroid Build Coastguard Worker # Emulate AutoDispatchBelowADInplaceOrView, which is not bound into python 165*da0073e9SAndroid Build Coastguard Worker with torch._C._AutoDispatchBelowAutograd(): 166*da0073e9SAndroid Build Coastguard Worker with torch._C._ExcludeDispatchKeyGuard( 167*da0073e9SAndroid Build Coastguard Worker torch._C.DispatchKeySet(torch._C.DispatchKey.ADInplaceOrView) 168*da0073e9SAndroid Build Coastguard Worker ): 169*da0073e9SAndroid Build Coastguard Worker return op(x) 170*da0073e9SAndroid Build Coastguard Worker 171*da0073e9SAndroid Build Coastguard Worker @staticmethod 172*da0073e9SAndroid Build Coastguard Worker def backward(ctx, gx): 173*da0073e9SAndroid Build Coastguard Worker return gx 174*da0073e9SAndroid Build Coastguard Worker 175*da0073e9SAndroid Build Coastguard Worker def foo_impl(x): 176*da0073e9SAndroid 
Build Coastguard Worker return x.view_as(x) 177*da0073e9SAndroid Build Coastguard Worker 178*da0073e9SAndroid Build Coastguard Worker def foo_meta(x): 179*da0073e9SAndroid Build Coastguard Worker return x.view_as(x) 180*da0073e9SAndroid Build Coastguard Worker 181*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", Foo.apply, "Autograd") 182*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", foo_impl, "CPU") 183*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", foo_meta, "Meta") 184*da0073e9SAndroid Build Coastguard Worker 185*da0073e9SAndroid Build Coastguard Worker x = torch.tensor(3.14159 / 3, requires_grad=True) 186*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 187*da0073e9SAndroid Build Coastguard Worker optests.OpCheckError, 188*da0073e9SAndroid Build Coastguard Worker "Argument x is not defined to alias output but was aliasing", 189*da0073e9SAndroid Build Coastguard Worker ): 190*da0073e9SAndroid Build Coastguard Worker torch.library.opcheck(op, (x,), {}) 191*da0073e9SAndroid Build Coastguard Worker 192*da0073e9SAndroid Build Coastguard Worker def test_missing_abstract_impl(self, device): 193*da0073e9SAndroid Build Coastguard Worker lib = self.lib() 194*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor x) -> Tensor") 195*da0073e9SAndroid Build Coastguard Worker op = self.ns().foo.default 196*da0073e9SAndroid Build Coastguard Worker 197*da0073e9SAndroid Build Coastguard Worker class Foo(torch.autograd.Function): 198*da0073e9SAndroid Build Coastguard Worker @staticmethod 199*da0073e9SAndroid Build Coastguard Worker def forward(ctx, x): 200*da0073e9SAndroid Build Coastguard Worker with torch._C._AutoDispatchBelowAutograd(): 201*da0073e9SAndroid Build Coastguard Worker return op(x) 202*da0073e9SAndroid Build Coastguard Worker 203*da0073e9SAndroid Build Coastguard Worker @staticmethod 204*da0073e9SAndroid Build Coastguard Worker def backward(ctx, gx): 205*da0073e9SAndroid Build Coastguard Worker return 2 * 
gx 206*da0073e9SAndroid Build Coastguard Worker 207*da0073e9SAndroid Build Coastguard Worker def foo_impl(x): 208*da0073e9SAndroid Build Coastguard Worker return torch.tensor(x.cpu().numpy() ** 2, device=x.device) 209*da0073e9SAndroid Build Coastguard Worker 210*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", Foo.apply, "Autograd") 211*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", foo_impl, "CPU") 212*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", foo_impl, "CUDA") 213*da0073e9SAndroid Build Coastguard Worker 214*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([0, 1.0], requires_grad=True) 215*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 216*da0073e9SAndroid Build Coastguard Worker optests.OpCheckError, 217*da0073e9SAndroid Build Coastguard Worker "_test_custom_op.foo.default", 218*da0073e9SAndroid Build Coastguard Worker ): 219*da0073e9SAndroid Build Coastguard Worker torch.library.opcheck(op, (x,), {}) 220*da0073e9SAndroid Build Coastguard Worker 221*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 222*da0073e9SAndroid Build Coastguard Worker def test_incorrect_abstract_impl(self, device): 223*da0073e9SAndroid Build Coastguard Worker lib = self.lib() 224*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor x) -> Tensor") 225*da0073e9SAndroid Build Coastguard Worker op = self.ns().foo.default 226*da0073e9SAndroid Build Coastguard Worker 227*da0073e9SAndroid Build Coastguard Worker class Foo(torch.autograd.Function): 228*da0073e9SAndroid Build Coastguard Worker @staticmethod 229*da0073e9SAndroid Build Coastguard Worker def forward(ctx, x): 230*da0073e9SAndroid Build Coastguard Worker # Emulate AutoDispatchBelowADInplaceOrView, which is not bound into python 231*da0073e9SAndroid Build Coastguard Worker guard = torch._C._AutoDispatchBelowAutograd() 232*da0073e9SAndroid Build Coastguard Worker guard2 = 
torch._C.ExcludeDispatchKeyGuard( 233*da0073e9SAndroid Build Coastguard Worker torch._C.DispatchKeySet(torch._C.DispatchKey.ADInplaceOrView) 234*da0073e9SAndroid Build Coastguard Worker ) 235*da0073e9SAndroid Build Coastguard Worker try: 236*da0073e9SAndroid Build Coastguard Worker return op(x) 237*da0073e9SAndroid Build Coastguard Worker finally: 238*da0073e9SAndroid Build Coastguard Worker del guard 239*da0073e9SAndroid Build Coastguard Worker del guard2 240*da0073e9SAndroid Build Coastguard Worker 241*da0073e9SAndroid Build Coastguard Worker @staticmethod 242*da0073e9SAndroid Build Coastguard Worker def backward(ctx, gx): 243*da0073e9SAndroid Build Coastguard Worker return gx 244*da0073e9SAndroid Build Coastguard Worker 245*da0073e9SAndroid Build Coastguard Worker def foo_impl(x): 246*da0073e9SAndroid Build Coastguard Worker return x**2 247*da0073e9SAndroid Build Coastguard Worker 248*da0073e9SAndroid Build Coastguard Worker def foo_meta(x): 249*da0073e9SAndroid Build Coastguard Worker return x.unsqueeze(1) ** 2 250*da0073e9SAndroid Build Coastguard Worker 251*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", Foo.apply, "Autograd") 252*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", foo_impl, "CPU") 253*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", foo_impl, "CUDA") 254*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", foo_meta, "Meta") 255*da0073e9SAndroid Build Coastguard Worker 256*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([0, 1.0], requires_grad=True) 257*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(optests.OpCheckError, "Shapes .* are not equal"): 258*da0073e9SAndroid Build Coastguard Worker torch.library.opcheck(op, (x,), {}) 259*da0073e9SAndroid Build Coastguard Worker 260*da0073e9SAndroid Build Coastguard Worker def test_missing_functionalization(self, device): 261*da0073e9SAndroid Build Coastguard Worker lib = self.lib() 262*da0073e9SAndroid Build Coastguard Worker 
lib.define("foo(Tensor(a!) x) -> Tensor(a!)") 263*da0073e9SAndroid Build Coastguard Worker op = self.ns().foo.default 264*da0073e9SAndroid Build Coastguard Worker 265*da0073e9SAndroid Build Coastguard Worker class Foo(torch.autograd.Function): 266*da0073e9SAndroid Build Coastguard Worker @staticmethod 267*da0073e9SAndroid Build Coastguard Worker def forward(ctx, x): 268*da0073e9SAndroid Build Coastguard Worker ctx.mark_dirty(x) 269*da0073e9SAndroid Build Coastguard Worker with torch._C._AutoDispatchBelowAutograd(): 270*da0073e9SAndroid Build Coastguard Worker return op(x) 271*da0073e9SAndroid Build Coastguard Worker 272*da0073e9SAndroid Build Coastguard Worker @staticmethod 273*da0073e9SAndroid Build Coastguard Worker def backward(ctx, gx): 274*da0073e9SAndroid Build Coastguard Worker return gx 275*da0073e9SAndroid Build Coastguard Worker 276*da0073e9SAndroid Build Coastguard Worker def foo_impl(x): 277*da0073e9SAndroid Build Coastguard Worker return x.sin_() 278*da0073e9SAndroid Build Coastguard Worker 279*da0073e9SAndroid Build Coastguard Worker def foo_meta(x): 280*da0073e9SAndroid Build Coastguard Worker return x 281*da0073e9SAndroid Build Coastguard Worker 282*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", Foo.apply, "Autograd") 283*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", foo_impl, "CPU") 284*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", foo_impl, "CUDA") 285*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", foo_meta, "Meta") 286*da0073e9SAndroid Build Coastguard Worker 287*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([0, 1.0]) 288*da0073e9SAndroid Build Coastguard Worker y = x.clone() 289*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 290*da0073e9SAndroid Build Coastguard Worker optests.OpCheckError, 291*da0073e9SAndroid Build Coastguard Worker "We only support functionalizing operators whose outputs do not have alias annotations", 292*da0073e9SAndroid Build Coastguard Worker ): 
293*da0073e9SAndroid Build Coastguard Worker torch.library.opcheck(op, (y,), {}) 294*da0073e9SAndroid Build Coastguard Worker 295*da0073e9SAndroid Build Coastguard Worker def test_autograd_registered_at_backend(self, device): 296*da0073e9SAndroid Build Coastguard Worker lib = self.lib() 297*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor x) -> Tensor") 298*da0073e9SAndroid Build Coastguard Worker op = self.ns().foo.default 299*da0073e9SAndroid Build Coastguard Worker 300*da0073e9SAndroid Build Coastguard Worker class Foo(torch.autograd.Function): 301*da0073e9SAndroid Build Coastguard Worker @staticmethod 302*da0073e9SAndroid Build Coastguard Worker def forward(ctx, x): 303*da0073e9SAndroid Build Coastguard Worker return x.clone() 304*da0073e9SAndroid Build Coastguard Worker 305*da0073e9SAndroid Build Coastguard Worker @staticmethod 306*da0073e9SAndroid Build Coastguard Worker def backward(ctx, gx): 307*da0073e9SAndroid Build Coastguard Worker return gx * 0.5 308*da0073e9SAndroid Build Coastguard Worker 309*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", Foo.apply, "CPU") 310*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", Foo.apply, "CUDA") 311*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", lambda x: x.clone(), "Meta") 312*da0073e9SAndroid Build Coastguard Worker 313*da0073e9SAndroid Build Coastguard Worker x = torch.randn([], requires_grad=True) 314*da0073e9SAndroid Build Coastguard Worker 315*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 316*da0073e9SAndroid Build Coastguard Worker torch.testing._internal.optests.OpCheckError, 317*da0073e9SAndroid Build Coastguard Worker "does not have an autograd kernel", 318*da0073e9SAndroid Build Coastguard Worker ): 319*da0073e9SAndroid Build Coastguard Worker torch.library.opcheck(op, (x,), {}) 320*da0073e9SAndroid Build Coastguard Worker 321*da0073e9SAndroid Build Coastguard Worker # I'm not sure why this is necessary 322*da0073e9SAndroid Build Coastguard 
Worker del lib 323*da0073e9SAndroid Build Coastguard Worker 324*da0073e9SAndroid Build Coastguard Worker def test_global_state_mutation(self, device): 325*da0073e9SAndroid Build Coastguard Worker lib = self.lib() 326*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor x) -> Tensor") 327*da0073e9SAndroid Build Coastguard Worker op = self.ns().foo.default 328*da0073e9SAndroid Build Coastguard Worker 329*da0073e9SAndroid Build Coastguard Worker class Foo(torch.autograd.Function): 330*da0073e9SAndroid Build Coastguard Worker invoked = 0 331*da0073e9SAndroid Build Coastguard Worker 332*da0073e9SAndroid Build Coastguard Worker @staticmethod 333*da0073e9SAndroid Build Coastguard Worker def forward(ctx, x): 334*da0073e9SAndroid Build Coastguard Worker Foo.invoked += 1 335*da0073e9SAndroid Build Coastguard Worker return x.clone() * Foo.invoked 336*da0073e9SAndroid Build Coastguard Worker 337*da0073e9SAndroid Build Coastguard Worker @staticmethod 338*da0073e9SAndroid Build Coastguard Worker def backward(ctx, gx): 339*da0073e9SAndroid Build Coastguard Worker return gx 340*da0073e9SAndroid Build Coastguard Worker 341*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", Foo.apply, "CompositeImplicitAutograd") 342*da0073e9SAndroid Build Coastguard Worker 343*da0073e9SAndroid Build Coastguard Worker x = torch.tensor(3.14159 / 3, requires_grad=True) 344*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 345*da0073e9SAndroid Build Coastguard Worker optests.OpCheckError, "eager-mode PyTorch vs AOTAutograd" 346*da0073e9SAndroid Build Coastguard Worker ): 347*da0073e9SAndroid Build Coastguard Worker torch.library.opcheck(op, (x,), {}) 348*da0073e9SAndroid Build Coastguard Worker 349*da0073e9SAndroid Build Coastguard Worker @ops(custom_op_db.custom_op_db, dtypes=OpDTypes.any_one) 350*da0073e9SAndroid Build Coastguard Worker def test_opcheck_opinfo(self, device, dtype, op): 351*da0073e9SAndroid Build Coastguard Worker for sample_input in 
op.sample_inputs( 352*da0073e9SAndroid Build Coastguard Worker device, dtype, requires_grad=op.supports_autograd 353*da0073e9SAndroid Build Coastguard Worker ): 354*da0073e9SAndroid Build Coastguard Worker args = [sample_input.input] + list(sample_input.args) 355*da0073e9SAndroid Build Coastguard Worker kwargs = sample_input.kwargs 356*da0073e9SAndroid Build Coastguard Worker torch.library.opcheck(op.op, args, kwargs) 357*da0073e9SAndroid Build Coastguard Worker 358*da0073e9SAndroid Build Coastguard Worker def test_opcheck_fails_basic(self, device): 359*da0073e9SAndroid Build Coastguard Worker @custom_op(f"{self.test_ns}::foo") 360*da0073e9SAndroid Build Coastguard Worker def foo(x: torch.Tensor) -> torch.Tensor: ... 361*da0073e9SAndroid Build Coastguard Worker 362*da0073e9SAndroid Build Coastguard Worker @foo.impl(["cpu", "cuda"]) 363*da0073e9SAndroid Build Coastguard Worker def foo_impl(x): 364*da0073e9SAndroid Build Coastguard Worker return x.sum() 365*da0073e9SAndroid Build Coastguard Worker 366*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, device=device, requires_grad=True) 367*da0073e9SAndroid Build Coastguard Worker # Triggers the CustomOp autograd NYI error 368*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 369*da0073e9SAndroid Build Coastguard Worker optests.OpCheckError, "Autograd has not been implemented for operator" 370*da0073e9SAndroid Build Coastguard Worker ): 371*da0073e9SAndroid Build Coastguard Worker torch.library.opcheck(self.get_op(f"{self.test_ns}::foo"), (x,), {}) 372*da0073e9SAndroid Build Coastguard Worker 373*da0073e9SAndroid Build Coastguard Worker def test_autograd_registration_check_autograd_kernel(self, device): 374*da0073e9SAndroid Build Coastguard Worker lib = self.lib() 375*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor x) -> Tensor") 376*da0073e9SAndroid Build Coastguard Worker op = self.ns().foo.default 377*da0073e9SAndroid Build Coastguard Worker 378*da0073e9SAndroid Build 
Coastguard Worker class Foo(torch.autograd.Function): 379*da0073e9SAndroid Build Coastguard Worker @staticmethod 380*da0073e9SAndroid Build Coastguard Worker def forward(ctx, x): 381*da0073e9SAndroid Build Coastguard Worker with torch._C._AutoDispatchBelowAutograd(): 382*da0073e9SAndroid Build Coastguard Worker return op(x) 383*da0073e9SAndroid Build Coastguard Worker 384*da0073e9SAndroid Build Coastguard Worker @staticmethod 385*da0073e9SAndroid Build Coastguard Worker def backward(ctx, gx): 386*da0073e9SAndroid Build Coastguard Worker return gx 387*da0073e9SAndroid Build Coastguard Worker 388*da0073e9SAndroid Build Coastguard Worker def foo_impl(x): 389*da0073e9SAndroid Build Coastguard Worker return x.sin() 390*da0073e9SAndroid Build Coastguard Worker 391*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", Foo.apply, "Autograd") 392*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", foo_impl, "CPU") 393*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", foo_impl, "CUDA") 394*da0073e9SAndroid Build Coastguard Worker 395*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, requires_grad=True, device=device) 396*da0073e9SAndroid Build Coastguard Worker # Should not raise 397*da0073e9SAndroid Build Coastguard Worker optests.autograd_registration_check(op, (x,), {}) 398*da0073e9SAndroid Build Coastguard Worker 399*da0073e9SAndroid Build Coastguard Worker def test_autograd_registration_check_compositeimplicitautograd(self, device): 400*da0073e9SAndroid Build Coastguard Worker lib = self.lib() 401*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor x) -> Tensor") 402*da0073e9SAndroid Build Coastguard Worker op = self.ns().foo.default 403*da0073e9SAndroid Build Coastguard Worker 404*da0073e9SAndroid Build Coastguard Worker def foo_impl(x): 405*da0073e9SAndroid Build Coastguard Worker return x.sin().cos() 406*da0073e9SAndroid Build Coastguard Worker 407*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", foo_impl, 
"CompositeImplicitAutograd") 408*da0073e9SAndroid Build Coastguard Worker 409*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, requires_grad=True, device=device) 410*da0073e9SAndroid Build Coastguard Worker # Should not raise 411*da0073e9SAndroid Build Coastguard Worker optests.autograd_registration_check(op, (x,), {}) 412*da0073e9SAndroid Build Coastguard Worker 413*da0073e9SAndroid Build Coastguard Worker def test_autograd_registration_check_incorrect_composite(self, device): 414*da0073e9SAndroid Build Coastguard Worker lib = self.lib() 415*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor x) -> Tensor") 416*da0073e9SAndroid Build Coastguard Worker op = self.ns().foo.default 417*da0073e9SAndroid Build Coastguard Worker 418*da0073e9SAndroid Build Coastguard Worker def foo_impl(x): 419*da0073e9SAndroid Build Coastguard Worker return x.sin().cos() 420*da0073e9SAndroid Build Coastguard Worker 421*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", foo_impl, "CompositeExplicitAutograd") 422*da0073e9SAndroid Build Coastguard Worker 423*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, requires_grad=True, device=device) 424*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, "incorrectly registered"): 425*da0073e9SAndroid Build Coastguard Worker optests.autograd_registration_check(op, (x,), {}) 426*da0073e9SAndroid Build Coastguard Worker 427*da0073e9SAndroid Build Coastguard Worker def test_autograd_registration_check_incorrect(self, device): 428*da0073e9SAndroid Build Coastguard Worker lib = self.lib() 429*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor x) -> Tensor") 430*da0073e9SAndroid Build Coastguard Worker op = self.ns().foo.default 431*da0073e9SAndroid Build Coastguard Worker 432*da0073e9SAndroid Build Coastguard Worker class Foo(torch.autograd.Function): 433*da0073e9SAndroid Build Coastguard Worker @staticmethod 434*da0073e9SAndroid Build Coastguard Worker def forward(ctx, 
x): 435*da0073e9SAndroid Build Coastguard Worker return torch.sin(x) 436*da0073e9SAndroid Build Coastguard Worker 437*da0073e9SAndroid Build Coastguard Worker @staticmethod 438*da0073e9SAndroid Build Coastguard Worker def backward(ctx, gx): 439*da0073e9SAndroid Build Coastguard Worker return gx 440*da0073e9SAndroid Build Coastguard Worker 441*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", Foo.apply, "CPU") 442*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", Foo.apply, "CUDA") 443*da0073e9SAndroid Build Coastguard Worker 444*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, requires_grad=True, device=device) 445*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, "incorrectly registered"): 446*da0073e9SAndroid Build Coastguard Worker optests.autograd_registration_check(op, (x,), {}) 447*da0073e9SAndroid Build Coastguard Worker 448*da0073e9SAndroid Build Coastguard Worker def test_assert_raises_regex(self, device): 449*da0073e9SAndroid Build Coastguard Worker from torch.testing._internal.optests.aot_autograd import assert_raises_regex 450*da0073e9SAndroid Build Coastguard Worker 451*da0073e9SAndroid Build Coastguard Worker with assert_raises_regex(RuntimeError, "c"): 452*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("abcd") 453*da0073e9SAndroid Build Coastguard Worker with assert_raises_regex(RuntimeError, "c.*"): 454*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("abcd") 455*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, "instead got"): 456*da0073e9SAndroid Build Coastguard Worker with assert_raises_regex(RuntimeError, "c.*"): 457*da0073e9SAndroid Build Coastguard Worker raise ValueError("abcd") 458*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, "Expected exception"): 459*da0073e9SAndroid Build Coastguard Worker with assert_raises_regex(RuntimeError, "c.*"): 460*da0073e9SAndroid Build Coastguard Worker 
pass 461*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, "to match regex"): 462*da0073e9SAndroid Build Coastguard Worker with assert_raises_regex(RuntimeError, "f"): 463*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("abcd") 464*da0073e9SAndroid Build Coastguard Worker 465*da0073e9SAndroid Build Coastguard Worker 466*da0073e9SAndroid Build Coastguard Workerclass TestCustomOp(CustomOpTestCaseBase): 467*da0073e9SAndroid Build Coastguard Worker test_ns = "_test_custom_op" 468*da0073e9SAndroid Build Coastguard Worker 469*da0073e9SAndroid Build Coastguard Worker @requires_compile 470*da0073e9SAndroid Build Coastguard Worker def test_functionalize_error(self): 471*da0073e9SAndroid Build Coastguard Worker with torch.library._scoped_library(self.test_ns, "FRAGMENT") as lib: 472*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor(a!) x) -> Tensor(a!)") 473*da0073e9SAndroid Build Coastguard Worker 474*da0073e9SAndroid Build Coastguard Worker def foo(x): 475*da0073e9SAndroid Build Coastguard Worker return x.sin_() 476*da0073e9SAndroid Build Coastguard Worker 477*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", foo, "CompositeExplicitAutograd") 478*da0073e9SAndroid Build Coastguard Worker foo_op = self.get_op(f"{self.test_ns}::foo") 479*da0073e9SAndroid Build Coastguard Worker 480*da0073e9SAndroid Build Coastguard Worker lib.define("bar(Tensor(a) x) -> Tensor(a)") 481*da0073e9SAndroid Build Coastguard Worker 482*da0073e9SAndroid Build Coastguard Worker def bar(x): 483*da0073e9SAndroid Build Coastguard Worker return x.view(-1) 484*da0073e9SAndroid Build Coastguard Worker 485*da0073e9SAndroid Build Coastguard Worker lib.impl("bar", bar, "CompositeExplicitAutograd") 486*da0073e9SAndroid Build Coastguard Worker bar_op = self.get_op(f"{self.test_ns}::bar") 487*da0073e9SAndroid Build Coastguard Worker 488*da0073e9SAndroid Build Coastguard Worker msg = r".*We only support functionalizing operators whose outputs do 
not have alias annotations" 489*da0073e9SAndroid Build Coastguard Worker 490*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 491*da0073e9SAndroid Build Coastguard Worker 492*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="aot_eager", fullgraph=True) 493*da0073e9SAndroid Build Coastguard Worker def f(x): 494*da0073e9SAndroid Build Coastguard Worker return foo_op(x) 495*da0073e9SAndroid Build Coastguard Worker 496*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="aot_eager", fullgraph=True) 497*da0073e9SAndroid Build Coastguard Worker def g(x): 498*da0073e9SAndroid Build Coastguard Worker return bar_op(x) 499*da0073e9SAndroid Build Coastguard Worker 500*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 501*da0073e9SAndroid Build Coastguard Worker f(x) 502*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 503*da0073e9SAndroid Build Coastguard Worker g(x) 504*da0073e9SAndroid Build Coastguard Worker 505*da0073e9SAndroid Build Coastguard Worker def test_invalid_schemas(self): 506*da0073e9SAndroid Build Coastguard Worker # function schmea validation goes through torchgen, so this is just a 507*da0073e9SAndroid Build Coastguard Worker # basic test. 
508*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, "Invalid function schema: foo"): 509*da0073e9SAndroid Build Coastguard Worker custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", "(") 510*da0073e9SAndroid Build Coastguard Worker 511*da0073e9SAndroid Build Coastguard Worker def test_invalid_qualname(self): 512*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "overload"): 513*da0073e9SAndroid Build Coastguard Worker custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo.Tensor", "() -> ()") 514*da0073e9SAndroid Build Coastguard Worker 515*da0073e9SAndroid Build Coastguard Worker def test_name_must_match(self): 516*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "to have name"): 517*da0073e9SAndroid Build Coastguard Worker 518*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 519*da0073e9SAndroid Build Coastguard Worker def baz(x: Tensor) -> Tensor: 520*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 521*da0073e9SAndroid Build Coastguard Worker 522*da0073e9SAndroid Build Coastguard Worker def test_unsupported_schemas(self): 523*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "only supports functional"): 524*da0073e9SAndroid Build Coastguard Worker custom_ops.custom_op( 525*da0073e9SAndroid Build Coastguard Worker f"{TestCustomOp.test_ns}::foo", "(Tensor(a!) 
x) -> Tensor(a)" 526*da0073e9SAndroid Build Coastguard Worker )(foo) 527*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "only supports functional"): 528*da0073e9SAndroid Build Coastguard Worker custom_ops.custom_op( 529*da0073e9SAndroid Build Coastguard Worker f"{TestCustomOp.test_ns}::foo", "(Tensor(a) x) -> Tensor(a)" 530*da0073e9SAndroid Build Coastguard Worker )(foo) 531*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "only supports functional"): 532*da0073e9SAndroid Build Coastguard Worker custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", "(Tensor x) -> ()")( 533*da0073e9SAndroid Build Coastguard Worker foo 534*da0073e9SAndroid Build Coastguard Worker ) 535*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "self"): 536*da0073e9SAndroid Build Coastguard Worker custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", "(Tensor self) -> ()")( 537*da0073e9SAndroid Build Coastguard Worker foo 538*da0073e9SAndroid Build Coastguard Worker ) 539*da0073e9SAndroid Build Coastguard Worker 540*da0073e9SAndroid Build Coastguard Worker # Tests for the older custom_op API 541*da0073e9SAndroid Build Coastguard Worker def test_schema_matches_signature(self): 542*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "signature to match"): 543*da0073e9SAndroid Build Coastguard Worker 544*da0073e9SAndroid Build Coastguard Worker @custom_op(f"{TestCustomOp.test_ns}::blah", "(Tensor y) -> Tensor") 545*da0073e9SAndroid Build Coastguard Worker def blah(x): 546*da0073e9SAndroid Build Coastguard Worker pass 547*da0073e9SAndroid Build Coastguard Worker 548*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "signature to match"): 549*da0073e9SAndroid Build Coastguard Worker 550*da0073e9SAndroid Build Coastguard Worker @custom_op( 551*da0073e9SAndroid Build Coastguard Worker f"{TestCustomOp.test_ns}::blah2", "(Tensor x, *, Tensor y) -> 
Tensor" 552*da0073e9SAndroid Build Coastguard Worker ) 553*da0073e9SAndroid Build Coastguard Worker def blah2(x, y): 554*da0073e9SAndroid Build Coastguard Worker pass 555*da0073e9SAndroid Build Coastguard Worker 556*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "signature to match"): 557*da0073e9SAndroid Build Coastguard Worker 558*da0073e9SAndroid Build Coastguard Worker @custom_op( 559*da0073e9SAndroid Build Coastguard Worker f"{TestCustomOp.test_ns}::blah3", 560*da0073e9SAndroid Build Coastguard Worker "(Tensor x, *, Tensor w, Tensor z) -> Tensor", 561*da0073e9SAndroid Build Coastguard Worker ) 562*da0073e9SAndroid Build Coastguard Worker def blah3(x, *, y, z): 563*da0073e9SAndroid Build Coastguard Worker pass 564*da0073e9SAndroid Build Coastguard Worker 565*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "signature to match"): 566*da0073e9SAndroid Build Coastguard Worker 567*da0073e9SAndroid Build Coastguard Worker @custom_op( 568*da0073e9SAndroid Build Coastguard Worker f"{TestCustomOp.test_ns}::blah4", 569*da0073e9SAndroid Build Coastguard Worker "(Tensor x, *, Tensor z, Tensor y) -> Tensor", 570*da0073e9SAndroid Build Coastguard Worker ) 571*da0073e9SAndroid Build Coastguard Worker def blah4(x, *, y, z): 572*da0073e9SAndroid Build Coastguard Worker pass 573*da0073e9SAndroid Build Coastguard Worker 574*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "not supported"): 575*da0073e9SAndroid Build Coastguard Worker 576*da0073e9SAndroid Build Coastguard Worker @custom_op(f"{TestCustomOp.test_ns}::blah5", "(Tensor x) -> Tensor") 577*da0073e9SAndroid Build Coastguard Worker def blah5(*args): 578*da0073e9SAndroid Build Coastguard Worker pass 579*da0073e9SAndroid Build Coastguard Worker 580*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "not supported"): 581*da0073e9SAndroid Build Coastguard Worker 582*da0073e9SAndroid Build Coastguard 
Worker @custom_op( 583*da0073e9SAndroid Build Coastguard Worker f"{TestCustomOp.test_ns}::blah6", "(*, Tensor z, Tensor y) -> Tensor" 584*da0073e9SAndroid Build Coastguard Worker ) 585*da0073e9SAndroid Build Coastguard Worker def blah6(**kwargs): 586*da0073e9SAndroid Build Coastguard Worker pass 587*da0073e9SAndroid Build Coastguard Worker 588*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "default arguments"): 589*da0073e9SAndroid Build Coastguard Worker 590*da0073e9SAndroid Build Coastguard Worker @custom_op( 591*da0073e9SAndroid Build Coastguard Worker f"{TestCustomOp.test_ns}::blah7", "(Tensor x, *, Tensor y) -> Tensor" 592*da0073e9SAndroid Build Coastguard Worker ) 593*da0073e9SAndroid Build Coastguard Worker def blah7(x=1, *, y): 594*da0073e9SAndroid Build Coastguard Worker pass 595*da0073e9SAndroid Build Coastguard Worker 596*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "default arguments"): 597*da0073e9SAndroid Build Coastguard Worker 598*da0073e9SAndroid Build Coastguard Worker @custom_op( 599*da0073e9SAndroid Build Coastguard Worker f"{TestCustomOp.test_ns}::blah8", "(Tensor x, *, Tensor y) -> Tensor" 600*da0073e9SAndroid Build Coastguard Worker ) 601*da0073e9SAndroid Build Coastguard Worker def blah8(x, *, y=1): 602*da0073e9SAndroid Build Coastguard Worker pass 603*da0073e9SAndroid Build Coastguard Worker 604*da0073e9SAndroid Build Coastguard Worker # kwonly-arg works 605*da0073e9SAndroid Build Coastguard Worker @custom_op( 606*da0073e9SAndroid Build Coastguard Worker f"{TestCustomOp.test_ns}::blah9", "(Tensor x, *, Tensor y) -> Tensor" 607*da0073e9SAndroid Build Coastguard Worker ) 608*da0073e9SAndroid Build Coastguard Worker def blah9(x, *, y): 609*da0073e9SAndroid Build Coastguard Worker pass 610*da0073e9SAndroid Build Coastguard Worker 611*da0073e9SAndroid Build Coastguard Worker def test_infer_schema_no_return(self): 612*da0073e9SAndroid Build Coastguard Worker with 
self.assertRaisesRegex( 613*da0073e9SAndroid Build Coastguard Worker ValueError, "No return type annotation was provided. Please add one." 614*da0073e9SAndroid Build Coastguard Worker ): 615*da0073e9SAndroid Build Coastguard Worker 616*da0073e9SAndroid Build Coastguard Worker @torch.library.custom_op("mylib::foo", mutates_args={}) 617*da0073e9SAndroid Build Coastguard Worker def foo(x: torch.Tensor, y: int): 618*da0073e9SAndroid Build Coastguard Worker return x * y 619*da0073e9SAndroid Build Coastguard Worker 620*da0073e9SAndroid Build Coastguard Worker def test_infer_schema_supported(self): 621*da0073e9SAndroid Build Coastguard Worker def a(x: Tensor) -> Tensor: 622*da0073e9SAndroid Build Coastguard Worker return torch.empty([]) 623*da0073e9SAndroid Build Coastguard Worker 624*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 625*da0073e9SAndroid Build Coastguard Worker infer_schema(a, mutates_args=()), """(Tensor x) -> Tensor""" 626*da0073e9SAndroid Build Coastguard Worker ) 627*da0073e9SAndroid Build Coastguard Worker 628*da0073e9SAndroid Build Coastguard Worker def kwonly1(x: Tensor, *, y: int, z: float) -> Tensor: 629*da0073e9SAndroid Build Coastguard Worker return torch.empty([]) 630*da0073e9SAndroid Build Coastguard Worker 631*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 632*da0073e9SAndroid Build Coastguard Worker infer_schema(kwonly1, mutates_args=()), 633*da0073e9SAndroid Build Coastguard Worker """(Tensor x, *, SymInt y, float z) -> Tensor""", 634*da0073e9SAndroid Build Coastguard Worker ) 635*da0073e9SAndroid Build Coastguard Worker 636*da0073e9SAndroid Build Coastguard Worker def kwonly2(*, y: Tensor) -> Tensor: 637*da0073e9SAndroid Build Coastguard Worker return torch.empty([]) 638*da0073e9SAndroid Build Coastguard Worker 639*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 640*da0073e9SAndroid Build Coastguard Worker infer_schema(kwonly2, mutates_args=()), """(*, Tensor y) -> Tensor""" 
641*da0073e9SAndroid Build Coastguard Worker ) 642*da0073e9SAndroid Build Coastguard Worker 643*da0073e9SAndroid Build Coastguard Worker def b( 644*da0073e9SAndroid Build Coastguard Worker x: Tensor, 645*da0073e9SAndroid Build Coastguard Worker y: int, 646*da0073e9SAndroid Build Coastguard Worker z: bool, 647*da0073e9SAndroid Build Coastguard Worker a: float, 648*da0073e9SAndroid Build Coastguard Worker b: torch.dtype, 649*da0073e9SAndroid Build Coastguard Worker c: torch.device, 650*da0073e9SAndroid Build Coastguard Worker d: torch.types.Number, 651*da0073e9SAndroid Build Coastguard Worker ) -> Tuple[Tensor, int, float, bool]: 652*da0073e9SAndroid Build Coastguard Worker return torch.empty([]), 1, 0.1, True 653*da0073e9SAndroid Build Coastguard Worker 654*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 655*da0073e9SAndroid Build Coastguard Worker infer_schema(b, mutates_args=()), 656*da0073e9SAndroid Build Coastguard Worker """(Tensor x, SymInt y, bool z, float a, ScalarType b, Device c, Scalar d) -> (Tensor, SymInt, float, bool)""", 657*da0073e9SAndroid Build Coastguard Worker ) 658*da0073e9SAndroid Build Coastguard Worker 659*da0073e9SAndroid Build Coastguard Worker def c( 660*da0073e9SAndroid Build Coastguard Worker x: Tensor, 661*da0073e9SAndroid Build Coastguard Worker y: Sequence[Tensor], 662*da0073e9SAndroid Build Coastguard Worker z: Optional[Tensor], 663*da0073e9SAndroid Build Coastguard Worker w: Sequence[Optional[Tensor]], 664*da0073e9SAndroid Build Coastguard Worker ) -> List[Tensor]: 665*da0073e9SAndroid Build Coastguard Worker return [torch.empty([])] 666*da0073e9SAndroid Build Coastguard Worker 667*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 668*da0073e9SAndroid Build Coastguard Worker infer_schema(c, mutates_args=()), 669*da0073e9SAndroid Build Coastguard Worker """(Tensor x, Tensor[] y, Tensor? 
z, Tensor?[] w) -> Tensor[]""", 670*da0073e9SAndroid Build Coastguard Worker ) 671*da0073e9SAndroid Build Coastguard Worker 672*da0073e9SAndroid Build Coastguard Worker def d(x: Tensor) -> Tuple[List[Tensor], Tensor]: 673*da0073e9SAndroid Build Coastguard Worker return [torch.empty([])], torch.empty([]) 674*da0073e9SAndroid Build Coastguard Worker 675*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 676*da0073e9SAndroid Build Coastguard Worker infer_schema(d, mutates_args=()), """(Tensor x) -> (Tensor[], Tensor)""" 677*da0073e9SAndroid Build Coastguard Worker ) 678*da0073e9SAndroid Build Coastguard Worker 679*da0073e9SAndroid Build Coastguard Worker def e() -> Tensor: 680*da0073e9SAndroid Build Coastguard Worker return torch.empty([]) 681*da0073e9SAndroid Build Coastguard Worker 682*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(infer_schema(e, mutates_args=()), """() -> Tensor""") 683*da0073e9SAndroid Build Coastguard Worker 684*da0073e9SAndroid Build Coastguard Worker def f(x: Tensor) -> None: 685*da0073e9SAndroid Build Coastguard Worker pass 686*da0073e9SAndroid Build Coastguard Worker 687*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 688*da0073e9SAndroid Build Coastguard Worker infer_schema(f, mutates_args=()), """(Tensor x) -> ()""" 689*da0073e9SAndroid Build Coastguard Worker ) 690*da0073e9SAndroid Build Coastguard Worker 691*da0073e9SAndroid Build Coastguard Worker def g( 692*da0073e9SAndroid Build Coastguard Worker x: Tensor, y: List[Tensor], z: List[Tensor], w: List[Optional[Tensor]] 693*da0073e9SAndroid Build Coastguard Worker ) -> None: 694*da0073e9SAndroid Build Coastguard Worker pass 695*da0073e9SAndroid Build Coastguard Worker 696*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 697*da0073e9SAndroid Build Coastguard Worker infer_schema(g, mutates_args=()), 698*da0073e9SAndroid Build Coastguard Worker """(Tensor x, Tensor[] y, Tensor[] z, Tensor?[] w) -> ()""", 
699*da0073e9SAndroid Build Coastguard Worker ) 700*da0073e9SAndroid Build Coastguard Worker 701*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 702*da0073e9SAndroid Build Coastguard Worker infer_schema(g, mutates_args={"x", "w", "z"}), 703*da0073e9SAndroid Build Coastguard Worker """(Tensor(a0!) x, Tensor[] y, Tensor(a2!)[] z, Tensor(a3!)?[] w) -> ()""", 704*da0073e9SAndroid Build Coastguard Worker ) 705*da0073e9SAndroid Build Coastguard Worker 706*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 707*da0073e9SAndroid Build Coastguard Worker infer_schema(g, mutates_args="unknown"), 708*da0073e9SAndroid Build Coastguard Worker """(Tensor(a0!) x, Tensor(a1!)[] y, Tensor(a2!)[] z, Tensor(a3!)?[] w) -> ()""", 709*da0073e9SAndroid Build Coastguard Worker ) 710*da0073e9SAndroid Build Coastguard Worker 711*da0073e9SAndroid Build Coastguard Worker def h( 712*da0073e9SAndroid Build Coastguard Worker x: Tensor, 713*da0073e9SAndroid Build Coastguard Worker a: Optional[int] = None, 714*da0073e9SAndroid Build Coastguard Worker b: float = 3.14, 715*da0073e9SAndroid Build Coastguard Worker c: bool = True, 716*da0073e9SAndroid Build Coastguard Worker d: int = 3, 717*da0073e9SAndroid Build Coastguard Worker e: str = "foo", 718*da0073e9SAndroid Build Coastguard Worker f: torch.dtype = torch.float, 719*da0073e9SAndroid Build Coastguard Worker g: torch.dtype = torch.float32, 720*da0073e9SAndroid Build Coastguard Worker h: torch.dtype = torch.int, 721*da0073e9SAndroid Build Coastguard Worker i: torch.device = torch.device("cpu:0"), 722*da0073e9SAndroid Build Coastguard Worker j: torch.device = "cpu", 723*da0073e9SAndroid Build Coastguard Worker ) -> None: 724*da0073e9SAndroid Build Coastguard Worker pass 725*da0073e9SAndroid Build Coastguard Worker 726*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 727*da0073e9SAndroid Build Coastguard Worker infer_schema(h, mutates_args=()), 728*da0073e9SAndroid Build Coastguard Worker ( 
729*da0073e9SAndroid Build Coastguard Worker """(Tensor x, SymInt? a=None, float b=3.14, bool c=True, SymInt d=3, str e="foo", """ 730*da0073e9SAndroid Build Coastguard Worker """ScalarType f=float32, ScalarType g=float32, ScalarType h=int32, Device i="cpu:0", Device j="cpu") -> ()""" 731*da0073e9SAndroid Build Coastguard Worker ), 732*da0073e9SAndroid Build Coastguard Worker ) 733*da0073e9SAndroid Build Coastguard Worker 734*da0073e9SAndroid Build Coastguard Worker def foo_impl(x: torch.Tensor) -> torch.Tensor: 735*da0073e9SAndroid Build Coastguard Worker return x.sin() 736*da0073e9SAndroid Build Coastguard Worker 737*da0073e9SAndroid Build Coastguard Worker schema = torch.library.infer_schema(foo_impl, op_name="myop", mutates_args={}) 738*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(schema, "myop(Tensor x) -> Tensor") 739*da0073e9SAndroid Build Coastguard Worker 740*da0073e9SAndroid Build Coastguard Worker def test_infer_schema_unsupported(self): 741*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "varargs"): 742*da0073e9SAndroid Build Coastguard Worker 743*da0073e9SAndroid Build Coastguard Worker def foo(*args): 744*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 745*da0073e9SAndroid Build Coastguard Worker 746*da0073e9SAndroid Build Coastguard Worker infer_schema(foo, mutates_args=()) 747*da0073e9SAndroid Build Coastguard Worker 748*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "varkwargs"): 749*da0073e9SAndroid Build Coastguard Worker 750*da0073e9SAndroid Build Coastguard Worker def foo(**kwargs): 751*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 752*da0073e9SAndroid Build Coastguard Worker 753*da0073e9SAndroid Build Coastguard Worker infer_schema(foo, mutates_args=()) 754*da0073e9SAndroid Build Coastguard Worker 755*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "must have a type annotation"): 
756*da0073e9SAndroid Build Coastguard Worker 757*da0073e9SAndroid Build Coastguard Worker def foo(x): 758*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 759*da0073e9SAndroid Build Coastguard Worker 760*da0073e9SAndroid Build Coastguard Worker infer_schema(foo, mutates_args=()) 761*da0073e9SAndroid Build Coastguard Worker 762*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "unsupported"): 763*da0073e9SAndroid Build Coastguard Worker 764*da0073e9SAndroid Build Coastguard Worker def foo(x: Tensor) -> Tuple[Tensor, ...]: 765*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 766*da0073e9SAndroid Build Coastguard Worker 767*da0073e9SAndroid Build Coastguard Worker infer_schema(foo, mutates_args=()) 768*da0073e9SAndroid Build Coastguard Worker 769*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "can be mutated"): 770*da0073e9SAndroid Build Coastguard Worker 771*da0073e9SAndroid Build Coastguard Worker def foo(x: Tensor, y: int) -> Tensor: 772*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 773*da0073e9SAndroid Build Coastguard Worker 774*da0073e9SAndroid Build Coastguard Worker infer_schema(foo, mutates_args={"y"}) 775*da0073e9SAndroid Build Coastguard Worker 776*da0073e9SAndroid Build Coastguard Worker def _generate_examples(self, typ): 777*da0073e9SAndroid Build Coastguard Worker if typ is int: 778*da0073e9SAndroid Build Coastguard Worker return [17] 779*da0073e9SAndroid Build Coastguard Worker if typ is float: 780*da0073e9SAndroid Build Coastguard Worker return [3.14] 781*da0073e9SAndroid Build Coastguard Worker if typ is bool: 782*da0073e9SAndroid Build Coastguard Worker return [True] 783*da0073e9SAndroid Build Coastguard Worker if typ is str: 784*da0073e9SAndroid Build Coastguard Worker return ["foo"] 785*da0073e9SAndroid Build Coastguard Worker if typ is torch.dtype: 786*da0073e9SAndroid Build Coastguard Worker return [torch.float32] 
787*da0073e9SAndroid Build Coastguard Worker if typ is torch.device: 788*da0073e9SAndroid Build Coastguard Worker return [torch.device("cpu")] 789*da0073e9SAndroid Build Coastguard Worker if typ == torch.types.Number: 790*da0073e9SAndroid Build Coastguard Worker return [2.718] 791*da0073e9SAndroid Build Coastguard Worker if typ is torch.Tensor: 792*da0073e9SAndroid Build Coastguard Worker return [torch.tensor(3)] 793*da0073e9SAndroid Build Coastguard Worker if typ == Optional[torch.types.Number]: 794*da0073e9SAndroid Build Coastguard Worker return [None, 2.718] 795*da0073e9SAndroid Build Coastguard Worker origin = typing.get_origin(typ) 796*da0073e9SAndroid Build Coastguard Worker if origin is Union: 797*da0073e9SAndroid Build Coastguard Worker args = typing.get_args(typ) 798*da0073e9SAndroid Build Coastguard Worker assert len(args) == 2 and (args[0] is type(None) or args[1] is type(None)) 799*da0073e9SAndroid Build Coastguard Worker elt = args[0] if args[1] is type(None) else args[1] 800*da0073e9SAndroid Build Coastguard Worker return self._generate_examples(elt) + [None] 801*da0073e9SAndroid Build Coastguard Worker if origin is list: 802*da0073e9SAndroid Build Coastguard Worker args = typing.get_args(typ) 803*da0073e9SAndroid Build Coastguard Worker assert len(args) == 1 804*da0073e9SAndroid Build Coastguard Worker elt = args[0] 805*da0073e9SAndroid Build Coastguard Worker return [ 806*da0073e9SAndroid Build Coastguard Worker self._generate_examples(elt), 807*da0073e9SAndroid Build Coastguard Worker self._generate_examples(elt), 808*da0073e9SAndroid Build Coastguard Worker self._generate_examples(elt), 809*da0073e9SAndroid Build Coastguard Worker ] 810*da0073e9SAndroid Build Coastguard Worker if origin is collections.abc.Sequence: 811*da0073e9SAndroid Build Coastguard Worker args = typing.get_args(typ) 812*da0073e9SAndroid Build Coastguard Worker assert len(args) == 1 813*da0073e9SAndroid Build Coastguard Worker examples = self._generate_examples(args[0]) 
814*da0073e9SAndroid Build Coastguard Worker return list(itertools.product(examples, examples)) + [] 815*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError( 816*da0073e9SAndroid Build Coastguard Worker f"testrunner cannot generate instanstance of type {typ}" 817*da0073e9SAndroid Build Coastguard Worker ) 818*da0073e9SAndroid Build Coastguard Worker 819*da0073e9SAndroid Build Coastguard Worker def test_supported_return_types_single_return(self): 820*da0073e9SAndroid Build Coastguard Worker for typ in torch._library.infer_schema.SUPPORTED_RETURN_TYPES: 821*da0073e9SAndroid Build Coastguard Worker for example in self._generate_examples(typ): 822*da0073e9SAndroid Build Coastguard Worker try: 823*da0073e9SAndroid Build Coastguard Worker 824*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{self.test_ns}::foo") 825*da0073e9SAndroid Build Coastguard Worker def foo(x: Tensor) -> typ: 826*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 827*da0073e9SAndroid Build Coastguard Worker 828*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl(f"{self.test_ns}::foo") 829*da0073e9SAndroid Build Coastguard Worker def foo_impl(x: Tensor) -> typ: 830*da0073e9SAndroid Build Coastguard Worker return example 831*da0073e9SAndroid Build Coastguard Worker 832*da0073e9SAndroid Build Coastguard Worker op = self.get_op(f"{self.test_ns}::foo") 833*da0073e9SAndroid Build Coastguard Worker result = op(torch.randn([])) 834*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, example, msg=f"{typ} {example}") 835*da0073e9SAndroid Build Coastguard Worker finally: 836*da0073e9SAndroid Build Coastguard Worker custom_ops._destroy(f"{self.test_ns}::foo") 837*da0073e9SAndroid Build Coastguard Worker 838*da0073e9SAndroid Build Coastguard Worker def test_supported_return_types_multi_return(self): 839*da0073e9SAndroid Build Coastguard Worker for typ in torch._library.infer_schema.SUPPORTED_RETURN_TYPES: 840*da0073e9SAndroid Build Coastguard 
Worker for example in self._generate_examples(typ): 841*da0073e9SAndroid Build Coastguard Worker try: 842*da0073e9SAndroid Build Coastguard Worker 843*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{self.test_ns}::foo") 844*da0073e9SAndroid Build Coastguard Worker def foo(x: Tensor) -> Tuple[typ, typ]: 845*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 846*da0073e9SAndroid Build Coastguard Worker 847*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl(f"{self.test_ns}::foo") 848*da0073e9SAndroid Build Coastguard Worker def foo_impl(x: Tensor) -> Tuple[typ, typ]: 849*da0073e9SAndroid Build Coastguard Worker return (example, example) 850*da0073e9SAndroid Build Coastguard Worker 851*da0073e9SAndroid Build Coastguard Worker op = self.get_op(f"{self.test_ns}::foo") 852*da0073e9SAndroid Build Coastguard Worker result = op(torch.randn([])) 853*da0073e9SAndroid Build Coastguard Worker expected = (example, example) 854*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, expected, msg=f"{typ} {example}") 855*da0073e9SAndroid Build Coastguard Worker finally: 856*da0073e9SAndroid Build Coastguard Worker custom_ops._destroy(f"{self.test_ns}::foo") 857*da0073e9SAndroid Build Coastguard Worker 858*da0073e9SAndroid Build Coastguard Worker def test_supported_param_types(self): 859*da0073e9SAndroid Build Coastguard Worker for typ in torch._library.infer_schema.SUPPORTED_PARAM_TYPES: 860*da0073e9SAndroid Build Coastguard Worker 861*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 862*da0073e9SAndroid Build Coastguard Worker def foo(x: Tensor, y: typ) -> Tensor: 863*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 864*da0073e9SAndroid Build Coastguard Worker 865*da0073e9SAndroid Build Coastguard Worker yeet = None 866*da0073e9SAndroid Build Coastguard Worker 867*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", 
device_types=["cpu"]) 868*da0073e9SAndroid Build Coastguard Worker def foo_cpu(x, y): 869*da0073e9SAndroid Build Coastguard Worker nonlocal yeet 870*da0073e9SAndroid Build Coastguard Worker yeet = y 871*da0073e9SAndroid Build Coastguard Worker return x.clone() 872*da0073e9SAndroid Build Coastguard Worker 873*da0073e9SAndroid Build Coastguard Worker try: 874*da0073e9SAndroid Build Coastguard Worker for example in self._generate_examples(typ): 875*da0073e9SAndroid Build Coastguard Worker op = self.get_op(f"{self.test_ns}::foo") 876*da0073e9SAndroid Build Coastguard Worker op(torch.randn([]), example) 877*da0073e9SAndroid Build Coastguard Worker self.assertEqual(yeet, example, msg=f"{typ} {example}") 878*da0073e9SAndroid Build Coastguard Worker yeet = None 879*da0073e9SAndroid Build Coastguard Worker finally: 880*da0073e9SAndroid Build Coastguard Worker custom_ops._destroy(f"{TestCustomOp.test_ns}::foo") 881*da0073e9SAndroid Build Coastguard Worker 882*da0073e9SAndroid Build Coastguard Worker def test_sequences(self): 883*da0073e9SAndroid Build Coastguard Worker # Sequence[int] gets automagically turned into int[] in the schema. 884*da0073e9SAndroid Build Coastguard Worker # This test checks that we actually do support arbitrary sequence types. 
885*da0073e9SAndroid Build Coastguard Worker class MySequence(collections.abc.Sequence): 886*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 887*da0073e9SAndroid Build Coastguard Worker self._container = [1, 2, 3] 888*da0073e9SAndroid Build Coastguard Worker 889*da0073e9SAndroid Build Coastguard Worker def __getitem__(self, idx): 890*da0073e9SAndroid Build Coastguard Worker return self._container[idx] 891*da0073e9SAndroid Build Coastguard Worker 892*da0073e9SAndroid Build Coastguard Worker def __len__(self): 893*da0073e9SAndroid Build Coastguard Worker return len(self._container) 894*da0073e9SAndroid Build Coastguard Worker 895*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{self.test_ns}::foo") 896*da0073e9SAndroid Build Coastguard Worker def foo(x: torch.Tensor, sizes: Sequence[int]) -> torch.Tensor: 897*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 898*da0073e9SAndroid Build Coastguard Worker 899*da0073e9SAndroid Build Coastguard Worker called = 0 900*da0073e9SAndroid Build Coastguard Worker 901*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl(f"{self.test_ns}::foo", device_types="cpu") 902*da0073e9SAndroid Build Coastguard Worker def foo_cpu(x, sizes): 903*da0073e9SAndroid Build Coastguard Worker nonlocal called 904*da0073e9SAndroid Build Coastguard Worker called += 1 905*da0073e9SAndroid Build Coastguard Worker # Dispatcher will normalize the sequence type into a List 906*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sizes, [1, 2, 3]) 907*da0073e9SAndroid Build Coastguard Worker return x.clone() 908*da0073e9SAndroid Build Coastguard Worker 909*da0073e9SAndroid Build Coastguard Worker x = torch.randn([]) 910*da0073e9SAndroid Build Coastguard Worker seq = MySequence() 911*da0073e9SAndroid Build Coastguard Worker op = self.get_op(f"{self.test_ns}::foo") 912*da0073e9SAndroid Build Coastguard Worker op(x, seq) 913*da0073e9SAndroid Build Coastguard Worker self.assertEqual(called, 1) 
914*da0073e9SAndroid Build Coastguard Worker 915*da0073e9SAndroid Build Coastguard Worker def test_unsupported_param_types(self): 916*da0073e9SAndroid Build Coastguard Worker # Not comprehensive (it doesn't need to be), just a check that our mechanism works 917*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "unsupported type"): 918*da0073e9SAndroid Build Coastguard Worker 919*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 920*da0073e9SAndroid Build Coastguard Worker def foo(x: Tensor, y: List[Optional[int]]) -> Tensor: 921*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 922*da0073e9SAndroid Build Coastguard Worker 923*da0073e9SAndroid Build Coastguard Worker del foo 924*da0073e9SAndroid Build Coastguard Worker 925*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "unsupported type"): 926*da0073e9SAndroid Build Coastguard Worker # int[N] in Dispatcher is a bit wild, so we don't try to support it. 927*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 928*da0073e9SAndroid Build Coastguard Worker def foo(x: Tensor, y: Tuple[int, int]) -> Tensor: 929*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 930*da0073e9SAndroid Build Coastguard Worker 931*da0073e9SAndroid Build Coastguard Worker del foo 932*da0073e9SAndroid Build Coastguard Worker 933*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, r"For example, typing.List\[int\]"): 934*da0073e9SAndroid Build Coastguard Worker # test that we propose a correct and supported type. 
935*da0073e9SAndroid Build Coastguard Worker @torch.library.custom_op(f"{TestCustomOp.test_ns}::foo", mutates_args={}) 936*da0073e9SAndroid Build Coastguard Worker def foo(x: Tensor, y: Tuple[int, int]) -> Tensor: 937*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 938*da0073e9SAndroid Build Coastguard Worker 939*da0073e9SAndroid Build Coastguard Worker del foo 940*da0073e9SAndroid Build Coastguard Worker 941*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError) as cm: 942*da0073e9SAndroid Build Coastguard Worker 943*da0073e9SAndroid Build Coastguard Worker @torch.library.custom_op(f"{TestCustomOp.test_ns}::foo", mutates_args={}) 944*da0073e9SAndroid Build Coastguard Worker def foo(x: Tensor, y: Tuple[int, float]) -> Tensor: 945*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 946*da0073e9SAndroid Build Coastguard Worker 947*da0073e9SAndroid Build Coastguard Worker del foo 948*da0073e9SAndroid Build Coastguard Worker 949*da0073e9SAndroid Build Coastguard Worker self.assertNotIn("example", str(cm.exception), "") 950*da0073e9SAndroid Build Coastguard Worker 951*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "unsupported type"): 952*da0073e9SAndroid Build Coastguard Worker 953*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 954*da0073e9SAndroid Build Coastguard Worker def foo(x: Tensor, y: Callable) -> Tensor: 955*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 956*da0073e9SAndroid Build Coastguard Worker 957*da0073e9SAndroid Build Coastguard Worker del foo 958*da0073e9SAndroid Build Coastguard Worker 959*da0073e9SAndroid Build Coastguard Worker def test_supported_schemas(self): 960*da0073e9SAndroid Build Coastguard Worker # All of these should already be tested by PyTorch codegen 961*da0073e9SAndroid Build Coastguard Worker # (we share the same mechanism), but here's a sanity check. 
962*da0073e9SAndroid Build Coastguard Worker schemas = [ 963*da0073e9SAndroid Build Coastguard Worker "(Tensor x) -> Tensor", 964*da0073e9SAndroid Build Coastguard Worker "(Tensor x) -> Tensor y", 965*da0073e9SAndroid Build Coastguard Worker "(Tensor[] x) -> Tensor y", 966*da0073e9SAndroid Build Coastguard Worker "(Tensor x) -> (Tensor, Tensor)", 967*da0073e9SAndroid Build Coastguard Worker "(Tensor x) -> (Tensor y, Tensor z)", 968*da0073e9SAndroid Build Coastguard Worker "(Tensor x) -> (Tensor y, Tensor z)", 969*da0073e9SAndroid Build Coastguard Worker ] 970*da0073e9SAndroid Build Coastguard Worker other_schemas = [ 971*da0073e9SAndroid Build Coastguard Worker "(Tensor x, Tensor w) -> (Tensor y, Tensor z)", 972*da0073e9SAndroid Build Coastguard Worker "(Tensor x, Tensor w) -> (Tensor, Tensor)", 973*da0073e9SAndroid Build Coastguard Worker "(Tensor x, Tensor w) -> Tensor", 974*da0073e9SAndroid Build Coastguard Worker "(Tensor? x, Tensor w) -> Tensor", 975*da0073e9SAndroid Build Coastguard Worker "(Tensor? x, Tensor[] w) -> Tensor", 976*da0073e9SAndroid Build Coastguard Worker "(Tensor x, int[] w) -> Tensor", 977*da0073e9SAndroid Build Coastguard Worker "(Tensor x, SymInt[] w) -> Tensor", 978*da0073e9SAndroid Build Coastguard Worker "(Tensor x, Scalar w) -> Tensor", 979*da0073e9SAndroid Build Coastguard Worker "(Tensor x, float w) -> Tensor", 980*da0073e9SAndroid Build Coastguard Worker "(Tensor x, float? 
w) -> Tensor", 981*da0073e9SAndroid Build Coastguard Worker "(Tensor x, bool[] w) -> Tensor", 982*da0073e9SAndroid Build Coastguard Worker ] 983*da0073e9SAndroid Build Coastguard Worker 984*da0073e9SAndroid Build Coastguard Worker for schema in schemas: 985*da0073e9SAndroid Build Coastguard Worker custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", schema) 986*da0073e9SAndroid Build Coastguard Worker custom_ops._destroy(f"{TestCustomOp.test_ns}::foo") 987*da0073e9SAndroid Build Coastguard Worker for schema in other_schemas: 988*da0073e9SAndroid Build Coastguard Worker custom_ops.custom_op(f"{TestCustomOp.test_ns}::bar", schema) 989*da0073e9SAndroid Build Coastguard Worker custom_ops._destroy(f"{TestCustomOp.test_ns}::bar") 990*da0073e9SAndroid Build Coastguard Worker 991*da0073e9SAndroid Build Coastguard Worker def test_reserved_ns(self): 992*da0073e9SAndroid Build Coastguard Worker from torch._custom_op.impl import RESERVED_NS 993*da0073e9SAndroid Build Coastguard Worker 994*da0073e9SAndroid Build Coastguard Worker for ns in RESERVED_NS: 995*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "is a reserved namespace"): 996*da0073e9SAndroid Build Coastguard Worker custom_ops.custom_op(f"{ns}::foo", "(Tensor x) -> Tensor") 997*da0073e9SAndroid Build Coastguard Worker 998*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "is a reserved namespace"): 999*da0073e9SAndroid Build Coastguard Worker 1000*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{ns}::foo2") 1001*da0073e9SAndroid Build Coastguard Worker def foo2(x: torch.Tensor) -> torch.Tensor: 1002*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1003*da0073e9SAndroid Build Coastguard Worker 1004*da0073e9SAndroid Build Coastguard Worker def test_private_ctor(self): 1005*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "CustomOp constructor is private"): 1006*da0073e9SAndroid Build 
Coastguard Worker CustomOp(None, None, None, None, None) 1007*da0073e9SAndroid Build Coastguard Worker 1008*da0073e9SAndroid Build Coastguard Worker def test_lifetime(self): 1009*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 1010*da0073e9SAndroid Build Coastguard Worker def foo(x: torch.Tensor) -> torch.Tensor: 1011*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1012*da0073e9SAndroid Build Coastguard Worker 1013*da0073e9SAndroid Build Coastguard Worker custom_op = torch._custom_op.impl.get_op(f"{TestCustomOp.test_ns}::foo") 1014*da0073e9SAndroid Build Coastguard Worker 1015*da0073e9SAndroid Build Coastguard Worker # We can't define an op multiple times, 1016*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "multiple times"): 1017*da0073e9SAndroid Build Coastguard Worker 1018*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 1019*da0073e9SAndroid Build Coastguard Worker def foo(x: torch.Tensor) -> torch.Tensor: # noqa: F811 1020*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1021*da0073e9SAndroid Build Coastguard Worker 1022*da0073e9SAndroid Build Coastguard Worker # Unless we delete the original op. 
1023*da0073e9SAndroid Build Coastguard Worker custom_ops._destroy(f"{TestCustomOp.test_ns}::foo") 1024*da0073e9SAndroid Build Coastguard Worker 1025*da0073e9SAndroid Build Coastguard Worker # Smoke test 1026*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 1027*da0073e9SAndroid Build Coastguard Worker def foo(x: torch.Tensor) -> torch.Tensor: # noqa: F811 1028*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1029*da0073e9SAndroid Build Coastguard Worker 1030*da0073e9SAndroid Build Coastguard Worker custom_ops._destroy(f"{TestCustomOp.test_ns}::foo") 1031*da0073e9SAndroid Build Coastguard Worker 1032*da0073e9SAndroid Build Coastguard Worker def test_autograd_notimplemented(self): 1033*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 1034*da0073e9SAndroid Build Coastguard Worker def foo(x: torch.Tensor) -> torch.Tensor: # noqa: F811 1035*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1036*da0073e9SAndroid Build Coastguard Worker 1037*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, requires_grad=True) 1038*da0073e9SAndroid Build Coastguard Worker op = self.get_op(f"{self.test_ns}::foo") 1039*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Autograd has not been implemented"): 1040*da0073e9SAndroid Build Coastguard Worker op(x) 1041*da0073e9SAndroid Build Coastguard Worker custom_ops._destroy(f"{TestCustomOp.test_ns}::foo") 1042*da0073e9SAndroid Build Coastguard Worker del foo 1043*da0073e9SAndroid Build Coastguard Worker 1044*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 1045*da0073e9SAndroid Build Coastguard Worker def foo(x: Sequence[torch.Tensor]) -> torch.Tensor: 1046*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1047*da0073e9SAndroid Build Coastguard Worker 1048*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 
requires_grad=True) 1049*da0073e9SAndroid Build Coastguard Worker y = torch.randn(3) 1050*da0073e9SAndroid Build Coastguard Worker op = self.get_op(f"{self.test_ns}::foo") 1051*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Autograd has not been implemented"): 1052*da0073e9SAndroid Build Coastguard Worker op([y, x]) 1053*da0073e9SAndroid Build Coastguard Worker custom_ops._destroy(f"{TestCustomOp.test_ns}::foo") 1054*da0073e9SAndroid Build Coastguard Worker del foo 1055*da0073e9SAndroid Build Coastguard Worker 1056*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 1057*da0073e9SAndroid Build Coastguard Worker def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 1058*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1059*da0073e9SAndroid Build Coastguard Worker 1060*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, requires_grad=True) 1061*da0073e9SAndroid Build Coastguard Worker y = torch.randn(3) 1062*da0073e9SAndroid Build Coastguard Worker op = self.get_op(f"{self.test_ns}::foo") 1063*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Autograd has not been implemented"): 1064*da0073e9SAndroid Build Coastguard Worker op(y, x) 1065*da0073e9SAndroid Build Coastguard Worker custom_ops._destroy(f"{TestCustomOp.test_ns}::foo") 1066*da0073e9SAndroid Build Coastguard Worker 1067*da0073e9SAndroid Build Coastguard Worker def test_autograd_notimplemented_gradmode(self): 1068*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 1069*da0073e9SAndroid Build Coastguard Worker def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 1070*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1071*da0073e9SAndroid Build Coastguard Worker 1072*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") 1073*da0073e9SAndroid Build Coastguard Worker 
def foo_impl(x, y): 1074*da0073e9SAndroid Build Coastguard Worker return x * y 1075*da0073e9SAndroid Build Coastguard Worker 1076*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, requires_grad=True) 1077*da0073e9SAndroid Build Coastguard Worker y = torch.randn(3) 1078*da0073e9SAndroid Build Coastguard Worker op = self.get_op(f"{self.test_ns}::foo") 1079*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 1080*da0073e9SAndroid Build Coastguard Worker # Shouldn't raise, because we are in no_grad 1081*da0073e9SAndroid Build Coastguard Worker op(y, x) 1082*da0073e9SAndroid Build Coastguard Worker 1083*da0073e9SAndroid Build Coastguard Worker def test_impl_cpu(self): 1084*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 1085*da0073e9SAndroid Build Coastguard Worker def foo(x: torch.Tensor) -> torch.Tensor: 1086*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1087*da0073e9SAndroid Build Coastguard Worker 1088*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cpu") 1089*da0073e9SAndroid Build Coastguard Worker def foo_cpu(x): 1090*da0073e9SAndroid Build Coastguard Worker return x.sin() 1091*da0073e9SAndroid Build Coastguard Worker 1092*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 1093*da0073e9SAndroid Build Coastguard Worker op = self.get_op(f"{self.test_ns}::foo") 1094*da0073e9SAndroid Build Coastguard Worker result = op(x) 1095*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, foo_cpu(x)) 1096*da0073e9SAndroid Build Coastguard Worker 1097*da0073e9SAndroid Build Coastguard Worker def test_impl_invalid_devices(self): 1098*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 1099*da0073e9SAndroid Build Coastguard Worker def foo(x: torch.Tensor) -> torch.Tensor: 1100*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1101*da0073e9SAndroid Build 
Coastguard Worker 1102*da0073e9SAndroid Build Coastguard Worker def foo_impl(x): 1103*da0073e9SAndroid Build Coastguard Worker return x.sin() 1104*da0073e9SAndroid Build Coastguard Worker 1105*da0073e9SAndroid Build Coastguard Worker from torch._custom_op.impl import SUPPORTED_DEVICE_TYPE_TO_KEY 1106*da0073e9SAndroid Build Coastguard Worker 1107*da0073e9SAndroid Build Coastguard Worker for device_type in SUPPORTED_DEVICE_TYPE_TO_KEY.keys(): 1108*da0073e9SAndroid Build Coastguard Worker # Smoke test: should not raise error 1109*da0073e9SAndroid Build Coastguard Worker custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types=device_type)( 1110*da0073e9SAndroid Build Coastguard Worker foo_impl 1111*da0073e9SAndroid Build Coastguard Worker ) 1112*da0073e9SAndroid Build Coastguard Worker 1113*da0073e9SAndroid Build Coastguard Worker # Not supported by this API: we can either support them in the future 1114*da0073e9SAndroid Build Coastguard Worker # or provide some other CustomOp.def_* function. This depends on how 1115*da0073e9SAndroid Build Coastguard Worker # common the use cases are. 
1116*da0073e9SAndroid Build Coastguard Worker for invalid_type in ["hip", "xla", "mkldnn", ["cpu", "hip"]]: 1117*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "we only support device_type"): 1118*da0073e9SAndroid Build Coastguard Worker custom_ops.impl( 1119*da0073e9SAndroid Build Coastguard Worker f"{TestCustomOp.test_ns}::foo", device_types=invalid_type 1120*da0073e9SAndroid Build Coastguard Worker )(foo_impl) 1121*da0073e9SAndroid Build Coastguard Worker 1122*da0073e9SAndroid Build Coastguard Worker def test_backward_partially_registered(self): 1123*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 1124*da0073e9SAndroid Build Coastguard Worker def foo(x: torch.Tensor) -> torch.Tensor: 1125*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1126*da0073e9SAndroid Build Coastguard Worker 1127*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") 1128*da0073e9SAndroid Build Coastguard Worker def foo_impl(x): 1129*da0073e9SAndroid Build Coastguard Worker return x.sin() 1130*da0073e9SAndroid Build Coastguard Worker 1131*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") 1132*da0073e9SAndroid Build Coastguard Worker def foo_backward(ctx, saved, grad): 1133*da0073e9SAndroid Build Coastguard Worker return grad * saved.cos() 1134*da0073e9SAndroid Build Coastguard Worker 1135*da0073e9SAndroid Build Coastguard Worker x = torch.randn([], requires_grad=True) 1136*da0073e9SAndroid Build Coastguard Worker op = self.get_op(f"{self.test_ns}::foo") 1137*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1138*da0073e9SAndroid Build Coastguard Worker RuntimeError, "unable to find a 'save_for_backward'" 1139*da0073e9SAndroid Build Coastguard Worker ): 1140*da0073e9SAndroid Build Coastguard Worker y = op(x) 1141*da0073e9SAndroid Build Coastguard Worker y.backward() 1142*da0073e9SAndroid 
Build Coastguard Worker 1143*da0073e9SAndroid Build Coastguard Worker def test_save_for_backward_inputs_are_namedtuple(self): 1144*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 1145*da0073e9SAndroid Build Coastguard Worker def foo(x: torch.Tensor) -> torch.Tensor: 1146*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1147*da0073e9SAndroid Build Coastguard Worker 1148*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") 1149*da0073e9SAndroid Build Coastguard Worker def foo_impl(x): 1150*da0073e9SAndroid Build Coastguard Worker return x.sin() 1151*da0073e9SAndroid Build Coastguard Worker 1152*da0073e9SAndroid Build Coastguard Worker hit = 0 1153*da0073e9SAndroid Build Coastguard Worker 1154*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") 1155*da0073e9SAndroid Build Coastguard Worker def foo_save_for_backward(inputs, output): 1156*da0073e9SAndroid Build Coastguard Worker nonlocal hit 1157*da0073e9SAndroid Build Coastguard Worker hit += 1 1158*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(inputs, tuple)) 1159*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(inputs._asdict().keys()), ["x"]) 1160*da0073e9SAndroid Build Coastguard Worker return inputs.x 1161*da0073e9SAndroid Build Coastguard Worker 1162*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") 1163*da0073e9SAndroid Build Coastguard Worker def foo_backward(ctx, saved, grad): 1164*da0073e9SAndroid Build Coastguard Worker return {"x": grad * saved.cos()} 1165*da0073e9SAndroid Build Coastguard Worker 1166*da0073e9SAndroid Build Coastguard Worker x = torch.randn([], requires_grad=True) 1167*da0073e9SAndroid Build Coastguard Worker op = self.get_op(f"{self.test_ns}::foo") 1168*da0073e9SAndroid Build Coastguard Worker y = op(x) 1169*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(hit, 1) 1170*da0073e9SAndroid Build Coastguard Worker y.backward() 1171*da0073e9SAndroid Build Coastguard Worker self.assertEqual(hit, 1) 1172*da0073e9SAndroid Build Coastguard Worker 1173*da0073e9SAndroid Build Coastguard Worker def test_backward_returns_dict(self): 1174*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 1175*da0073e9SAndroid Build Coastguard Worker def foo(x: torch.Tensor) -> torch.Tensor: 1176*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1177*da0073e9SAndroid Build Coastguard Worker 1178*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") 1179*da0073e9SAndroid Build Coastguard Worker def foo_impl(x): 1180*da0073e9SAndroid Build Coastguard Worker return x.sin() 1181*da0073e9SAndroid Build Coastguard Worker 1182*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") 1183*da0073e9SAndroid Build Coastguard Worker def foo_save_for_backward(inputs, output): 1184*da0073e9SAndroid Build Coastguard Worker return inputs.x 1185*da0073e9SAndroid Build Coastguard Worker 1186*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") 1187*da0073e9SAndroid Build Coastguard Worker def foo_backward(ctx, saved, grad): 1188*da0073e9SAndroid Build Coastguard Worker return grad * saved.cos() 1189*da0073e9SAndroid Build Coastguard Worker 1190*da0073e9SAndroid Build Coastguard Worker x = torch.randn([], requires_grad=True) 1191*da0073e9SAndroid Build Coastguard Worker op = self.get_op(f"{self.test_ns}::foo") 1192*da0073e9SAndroid Build Coastguard Worker y = op(x) 1193*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "to be a dict"): 1194*da0073e9SAndroid Build Coastguard Worker y.backward() 1195*da0073e9SAndroid Build Coastguard Worker 1196*da0073e9SAndroid Build Coastguard Worker def test_backward_dict_invalid_keys(self): 
1197*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 1198*da0073e9SAndroid Build Coastguard Worker def foo(x: torch.Tensor) -> torch.Tensor: 1199*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1200*da0073e9SAndroid Build Coastguard Worker 1201*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") 1202*da0073e9SAndroid Build Coastguard Worker def foo_impl(x): 1203*da0073e9SAndroid Build Coastguard Worker return x.sin() 1204*da0073e9SAndroid Build Coastguard Worker 1205*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") 1206*da0073e9SAndroid Build Coastguard Worker def foo_save_for_backward(inputs, output): 1207*da0073e9SAndroid Build Coastguard Worker return inputs.x 1208*da0073e9SAndroid Build Coastguard Worker 1209*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") 1210*da0073e9SAndroid Build Coastguard Worker def foo_backward(ctx, saved, grad): 1211*da0073e9SAndroid Build Coastguard Worker return {"x": grad * saved.cos(), "y": None} 1212*da0073e9SAndroid Build Coastguard Worker 1213*da0073e9SAndroid Build Coastguard Worker x = torch.randn([], requires_grad=True) 1214*da0073e9SAndroid Build Coastguard Worker op = self.get_op(f"{self.test_ns}::foo") 1215*da0073e9SAndroid Build Coastguard Worker y = op(x) 1216*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "to have keys {'x'}"): 1217*da0073e9SAndroid Build Coastguard Worker y.backward() 1218*da0073e9SAndroid Build Coastguard Worker 1219*da0073e9SAndroid Build Coastguard Worker def test_backward_dict_grad_for_nontensor(self): 1220*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 1221*da0073e9SAndroid Build Coastguard Worker def foo(x: torch.Tensor, dim: int) -> torch.Tensor: 1222*da0073e9SAndroid Build Coastguard Worker raise 
NotImplementedError 1223*da0073e9SAndroid Build Coastguard Worker 1224*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") 1225*da0073e9SAndroid Build Coastguard Worker def foo_impl(x, dim): 1226*da0073e9SAndroid Build Coastguard Worker return x.sin() 1227*da0073e9SAndroid Build Coastguard Worker 1228*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") 1229*da0073e9SAndroid Build Coastguard Worker def foo_save_for_backward(inputs, output): 1230*da0073e9SAndroid Build Coastguard Worker return inputs.x 1231*da0073e9SAndroid Build Coastguard Worker 1232*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") 1233*da0073e9SAndroid Build Coastguard Worker def foo_backward(ctx, saved, grad): 1234*da0073e9SAndroid Build Coastguard Worker return {"x": grad * saved.cos(), "dim": None} 1235*da0073e9SAndroid Build Coastguard Worker 1236*da0073e9SAndroid Build Coastguard Worker x = torch.randn([], requires_grad=True) 1237*da0073e9SAndroid Build Coastguard Worker op = self.get_op(f"{self.test_ns}::foo") 1238*da0073e9SAndroid Build Coastguard Worker y = op(x, 32) 1239*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "non-Tensor-like types"): 1240*da0073e9SAndroid Build Coastguard Worker y.backward() 1241*da0073e9SAndroid Build Coastguard Worker 1242*da0073e9SAndroid Build Coastguard Worker def test_backward_dict_requires_keys_for_input_tensors(self): 1243*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 1244*da0073e9SAndroid Build Coastguard Worker def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 1245*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1246*da0073e9SAndroid Build Coastguard Worker 1247*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") 1248*da0073e9SAndroid Build Coastguard Worker def 
foo_impl(x, y): 1249*da0073e9SAndroid Build Coastguard Worker return x.sin() 1250*da0073e9SAndroid Build Coastguard Worker 1251*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") 1252*da0073e9SAndroid Build Coastguard Worker def foo_save_for_backward(inputs, output): 1253*da0073e9SAndroid Build Coastguard Worker return inputs.x 1254*da0073e9SAndroid Build Coastguard Worker 1255*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") 1256*da0073e9SAndroid Build Coastguard Worker def foo_backward(ctx, saved, grad): 1257*da0073e9SAndroid Build Coastguard Worker return {"x": grad * saved.cos()} 1258*da0073e9SAndroid Build Coastguard Worker 1259*da0073e9SAndroid Build Coastguard Worker x = torch.randn([], requires_grad=True) 1260*da0073e9SAndroid Build Coastguard Worker op = self.get_op(f"{self.test_ns}::foo") 1261*da0073e9SAndroid Build Coastguard Worker y = op(x, x) 1262*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"to have keys {.*'y'.*}"): 1263*da0073e9SAndroid Build Coastguard Worker y.backward() 1264*da0073e9SAndroid Build Coastguard Worker 1265*da0073e9SAndroid Build Coastguard Worker def test_backward_dict_requires_keys_for_input_optional_tensors(self): 1266*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 1267*da0073e9SAndroid Build Coastguard Worker def foo(x: torch.Tensor, y: Optional[torch.Tensor]) -> torch.Tensor: 1268*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1269*da0073e9SAndroid Build Coastguard Worker 1270*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") 1271*da0073e9SAndroid Build Coastguard Worker def foo_impl(x, y): 1272*da0073e9SAndroid Build Coastguard Worker return x.sin() 1273*da0073e9SAndroid Build Coastguard Worker 1274*da0073e9SAndroid Build Coastguard Worker 
@custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") 1275*da0073e9SAndroid Build Coastguard Worker def foo_save_for_backward(inputs, output): 1276*da0073e9SAndroid Build Coastguard Worker return inputs.x 1277*da0073e9SAndroid Build Coastguard Worker 1278*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") 1279*da0073e9SAndroid Build Coastguard Worker def foo_backward(ctx, saved, grad): 1280*da0073e9SAndroid Build Coastguard Worker return {"x": grad * saved.cos()} 1281*da0073e9SAndroid Build Coastguard Worker 1282*da0073e9SAndroid Build Coastguard Worker x = torch.randn([], requires_grad=True) 1283*da0073e9SAndroid Build Coastguard Worker op = self.get_op(f"{self.test_ns}::foo") 1284*da0073e9SAndroid Build Coastguard Worker y = op(x, None) 1285*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"to have keys {.*'y'.*}"): 1286*da0073e9SAndroid Build Coastguard Worker y.backward() 1287*da0073e9SAndroid Build Coastguard Worker 1288*da0073e9SAndroid Build Coastguard Worker def test_backward_grads_are_tensor_or_none(self): 1289*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 1290*da0073e9SAndroid Build Coastguard Worker def foo(x: torch.Tensor) -> torch.Tensor: 1291*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1292*da0073e9SAndroid Build Coastguard Worker 1293*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") 1294*da0073e9SAndroid Build Coastguard Worker def foo_impl(x): 1295*da0073e9SAndroid Build Coastguard Worker return x.sin() 1296*da0073e9SAndroid Build Coastguard Worker 1297*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") 1298*da0073e9SAndroid Build Coastguard Worker def foo_save_for_backward(inputs, output): 1299*da0073e9SAndroid Build Coastguard Worker return inputs.x 1300*da0073e9SAndroid Build 
Coastguard Worker 1301*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") 1302*da0073e9SAndroid Build Coastguard Worker def foo_backward(ctx, saved, grad): 1303*da0073e9SAndroid Build Coastguard Worker return {"x": (grad * saved.cos(),)} 1304*da0073e9SAndroid Build Coastguard Worker 1305*da0073e9SAndroid Build Coastguard Worker x = torch.randn([], requires_grad=True) 1306*da0073e9SAndroid Build Coastguard Worker op = self.get_op(f"{self.test_ns}::foo") 1307*da0073e9SAndroid Build Coastguard Worker y = op(x) 1308*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "either None or a Tensor"): 1309*da0073e9SAndroid Build Coastguard Worker y.backward() 1310*da0073e9SAndroid Build Coastguard Worker 1311*da0073e9SAndroid Build Coastguard Worker def test_backward_tensorlist_input_requires_list_grads_with_same_numel(self): 1312*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 1313*da0073e9SAndroid Build Coastguard Worker def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor: 1314*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1315*da0073e9SAndroid Build Coastguard Worker 1316*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") 1317*da0073e9SAndroid Build Coastguard Worker def foo_impl(xs): 1318*da0073e9SAndroid Build Coastguard Worker return xs[0].sin() 1319*da0073e9SAndroid Build Coastguard Worker 1320*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") 1321*da0073e9SAndroid Build Coastguard Worker def foo_save_for_backward(inputs, output): 1322*da0073e9SAndroid Build Coastguard Worker return inputs.xs[0] 1323*da0073e9SAndroid Build Coastguard Worker 1324*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") 1325*da0073e9SAndroid Build Coastguard Worker def foo_backward(ctx, saved, grad): 
1326*da0073e9SAndroid Build Coastguard Worker return {"xs": [grad * saved.cos(), None]} 1327*da0073e9SAndroid Build Coastguard Worker 1328*da0073e9SAndroid Build Coastguard Worker xs = [torch.randn([], requires_grad=True) for _ in range(3)] 1329*da0073e9SAndroid Build Coastguard Worker op = self.get_op(f"{self.test_ns}::foo") 1330*da0073e9SAndroid Build Coastguard Worker y = op(xs) 1331*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "3 gradients but got 2"): 1332*da0073e9SAndroid Build Coastguard Worker y.backward() 1333*da0073e9SAndroid Build Coastguard Worker 1334*da0073e9SAndroid Build Coastguard Worker def test_backward_tensorlist_input_requires_list_grads_none_or_Tensor(self): 1335*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 1336*da0073e9SAndroid Build Coastguard Worker def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor: 1337*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1338*da0073e9SAndroid Build Coastguard Worker 1339*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") 1340*da0073e9SAndroid Build Coastguard Worker def foo_impl(xs): 1341*da0073e9SAndroid Build Coastguard Worker return xs[0].sin() 1342*da0073e9SAndroid Build Coastguard Worker 1343*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") 1344*da0073e9SAndroid Build Coastguard Worker def foo_save_for_backward(inputs, output): 1345*da0073e9SAndroid Build Coastguard Worker return inputs.xs[0] 1346*da0073e9SAndroid Build Coastguard Worker 1347*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") 1348*da0073e9SAndroid Build Coastguard Worker def foo_backward(ctx, saved, grad): 1349*da0073e9SAndroid Build Coastguard Worker return {"xs": [grad * saved.cos(), None, (None,)]} 1350*da0073e9SAndroid Build Coastguard Worker 1351*da0073e9SAndroid Build Coastguard 
Worker xs = [torch.randn([], requires_grad=True) for _ in range(3)] 1352*da0073e9SAndroid Build Coastguard Worker op = self.get_op(f"{self.test_ns}::foo") 1353*da0073e9SAndroid Build Coastguard Worker y = op(xs) 1354*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "None or Tensor"): 1355*da0073e9SAndroid Build Coastguard Worker y.backward() 1356*da0073e9SAndroid Build Coastguard Worker 1357*da0073e9SAndroid Build Coastguard Worker def test_backward_tensorlist_input_requires_list_grads(self): 1358*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 1359*da0073e9SAndroid Build Coastguard Worker def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor: 1360*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1361*da0073e9SAndroid Build Coastguard Worker 1362*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") 1363*da0073e9SAndroid Build Coastguard Worker def foo_impl(xs): 1364*da0073e9SAndroid Build Coastguard Worker return xs[0].sin() 1365*da0073e9SAndroid Build Coastguard Worker 1366*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") 1367*da0073e9SAndroid Build Coastguard Worker def foo_save_for_backward(inputs, output): 1368*da0073e9SAndroid Build Coastguard Worker return inputs.xs[0] 1369*da0073e9SAndroid Build Coastguard Worker 1370*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") 1371*da0073e9SAndroid Build Coastguard Worker def foo_backward(ctx, saved, grad): 1372*da0073e9SAndroid Build Coastguard Worker return {"xs": None} 1373*da0073e9SAndroid Build Coastguard Worker 1374*da0073e9SAndroid Build Coastguard Worker xs = [torch.randn([], requires_grad=True) for _ in range(3)] 1375*da0073e9SAndroid Build Coastguard Worker op = self.get_op(f"{self.test_ns}::foo") 1376*da0073e9SAndroid Build Coastguard Worker y = op(xs) 
1377*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "list of gradients"): 1378*da0073e9SAndroid Build Coastguard Worker y.backward() 1379*da0073e9SAndroid Build Coastguard Worker 1380*da0073e9SAndroid Build Coastguard Worker def test_backward_output_differentiability_type(self): 1381*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 1382*da0073e9SAndroid Build Coastguard Worker def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor: 1383*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1384*da0073e9SAndroid Build Coastguard Worker 1385*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "output_differentiability"): 1386*da0073e9SAndroid Build Coastguard Worker 1387*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl_backward( 1388*da0073e9SAndroid Build Coastguard Worker f"{TestCustomOp.test_ns}::foo", output_differentiability=True 1389*da0073e9SAndroid Build Coastguard Worker ) 1390*da0073e9SAndroid Build Coastguard Worker def foo_backward(ctx, saved, grad): 1391*da0073e9SAndroid Build Coastguard Worker return {"xs": None} 1392*da0073e9SAndroid Build Coastguard Worker 1393*da0073e9SAndroid Build Coastguard Worker def test_backward_output_differentiability_numel(self): 1394*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 1395*da0073e9SAndroid Build Coastguard Worker def foo(xs: Sequence[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: 1396*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1397*da0073e9SAndroid Build Coastguard Worker 1398*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "output_differentiability"): 1399*da0073e9SAndroid Build Coastguard Worker 1400*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl_backward( 1401*da0073e9SAndroid Build Coastguard Worker f"{TestCustomOp.test_ns}::foo", 
output_differentiability=[True] 1402*da0073e9SAndroid Build Coastguard Worker ) 1403*da0073e9SAndroid Build Coastguard Worker def foo_backward(ctx, saved, grad): 1404*da0073e9SAndroid Build Coastguard Worker return {"xs": None} 1405*da0073e9SAndroid Build Coastguard Worker 1406*da0073e9SAndroid Build Coastguard Worker def test_backward_output_differentiability_tensorlist(self): 1407*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{self.test_ns}::foo") 1408*da0073e9SAndroid Build Coastguard Worker def foo(x: Tensor) -> Tuple[List[Tensor], Tensor]: 1409*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1410*da0073e9SAndroid Build Coastguard Worker 1411*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl(f"{self.test_ns}::foo") 1412*da0073e9SAndroid Build Coastguard Worker def foo_impl(x): 1413*da0073e9SAndroid Build Coastguard Worker return [x.clone(), x.clone()], x.clone() 1414*da0073e9SAndroid Build Coastguard Worker 1415*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") 1416*da0073e9SAndroid Build Coastguard Worker def foo_save_for_backward(inputs, output): 1417*da0073e9SAndroid Build Coastguard Worker return [] 1418*da0073e9SAndroid Build Coastguard Worker 1419*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl_backward( 1420*da0073e9SAndroid Build Coastguard Worker f"{TestCustomOp.test_ns}::foo", output_differentiability=[False, True] 1421*da0073e9SAndroid Build Coastguard Worker ) 1422*da0073e9SAndroid Build Coastguard Worker def foo_backward(ctx, saved, grad_lst, grad): 1423*da0073e9SAndroid Build Coastguard Worker return {"x": grad} 1424*da0073e9SAndroid Build Coastguard Worker 1425*da0073e9SAndroid Build Coastguard Worker op = self.get_op(f"{self.test_ns}::foo") 1426*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, requires_grad=True) 1427*da0073e9SAndroid Build Coastguard Worker [a, b], c = op(x) 1428*da0073e9SAndroid Build Coastguard Worker 
self.assertFalse(a.requires_grad) 1429*da0073e9SAndroid Build Coastguard Worker self.assertFalse(b.requires_grad) 1430*da0073e9SAndroid Build Coastguard Worker self.assertTrue(c.requires_grad) 1431*da0073e9SAndroid Build Coastguard Worker 1432*da0073e9SAndroid Build Coastguard Worker def test_backward_output_differentiability_non_tensor(self): 1433*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{self.test_ns}::foo") 1434*da0073e9SAndroid Build Coastguard Worker def foo(x: Tensor) -> Tuple[Tensor, int]: 1435*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1436*da0073e9SAndroid Build Coastguard Worker 1437*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl(f"{self.test_ns}::foo") 1438*da0073e9SAndroid Build Coastguard Worker def foo_impl(x): 1439*da0073e9SAndroid Build Coastguard Worker return x.clone(), 3 1440*da0073e9SAndroid Build Coastguard Worker 1441*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") 1442*da0073e9SAndroid Build Coastguard Worker def foo_save_for_backward(inputs, output): 1443*da0073e9SAndroid Build Coastguard Worker return [] 1444*da0073e9SAndroid Build Coastguard Worker 1445*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl_backward( 1446*da0073e9SAndroid Build Coastguard Worker f"{TestCustomOp.test_ns}::foo", output_differentiability=[True, True] 1447*da0073e9SAndroid Build Coastguard Worker ) 1448*da0073e9SAndroid Build Coastguard Worker def foo_backward(ctx, saved, grad0, grad1): 1449*da0073e9SAndroid Build Coastguard Worker return {"x": grad0} 1450*da0073e9SAndroid Build Coastguard Worker 1451*da0073e9SAndroid Build Coastguard Worker op = self.get_op(f"{self.test_ns}::foo") 1452*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, requires_grad=True) 1453*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "is not a Tensor"): 1454*da0073e9SAndroid Build Coastguard Worker op(x) 
1455*da0073e9SAndroid Build Coastguard Worker 1456*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA, "requires CUDA") 1457*da0073e9SAndroid Build Coastguard Worker def test_impl_separate(self): 1458*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 1459*da0073e9SAndroid Build Coastguard Worker def foo(x: torch.Tensor) -> torch.Tensor: 1460*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1461*da0073e9SAndroid Build Coastguard Worker 1462*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cpu") 1463*da0073e9SAndroid Build Coastguard Worker def foo_cpu(x): 1464*da0073e9SAndroid Build Coastguard Worker return x.sin() 1465*da0073e9SAndroid Build Coastguard Worker 1466*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cuda") 1467*da0073e9SAndroid Build Coastguard Worker def foo_cuda(x): 1468*da0073e9SAndroid Build Coastguard Worker return x.cos() 1469*da0073e9SAndroid Build Coastguard Worker 1470*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 1471*da0073e9SAndroid Build Coastguard Worker op = self.get_op(f"{self.test_ns}::foo") 1472*da0073e9SAndroid Build Coastguard Worker result = op(x) 1473*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, foo_cpu(x)) 1474*da0073e9SAndroid Build Coastguard Worker 1475*da0073e9SAndroid Build Coastguard Worker x_cuda = x.cuda() 1476*da0073e9SAndroid Build Coastguard Worker op = self.get_op(f"{self.test_ns}::foo") 1477*da0073e9SAndroid Build Coastguard Worker result = op(x_cuda) 1478*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, foo_cuda(x_cuda)) 1479*da0073e9SAndroid Build Coastguard Worker 1480*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA, "requires CUDA") 1481*da0073e9SAndroid Build Coastguard Worker def test_impl_multiple(self): 1482*da0073e9SAndroid Build Coastguard 
Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 1483*da0073e9SAndroid Build Coastguard Worker def foo(x: torch.Tensor) -> torch.Tensor: 1484*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1485*da0073e9SAndroid Build Coastguard Worker 1486*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") 1487*da0073e9SAndroid Build Coastguard Worker def foo_impl(x): 1488*da0073e9SAndroid Build Coastguard Worker return x.cos() 1489*da0073e9SAndroid Build Coastguard Worker 1490*da0073e9SAndroid Build Coastguard Worker op = self.get_op(f"{self.test_ns}::foo") 1491*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 1492*da0073e9SAndroid Build Coastguard Worker result = op(x) 1493*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, foo_impl(x)) 1494*da0073e9SAndroid Build Coastguard Worker 1495*da0073e9SAndroid Build Coastguard Worker x_cuda = x.cuda() 1496*da0073e9SAndroid Build Coastguard Worker result = op(x_cuda) 1497*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, foo_impl(x_cuda)) 1498*da0073e9SAndroid Build Coastguard Worker 1499*da0073e9SAndroid Build Coastguard Worker def test_impl_abstract_overload(self): 1500*da0073e9SAndroid Build Coastguard Worker lib = self.lib() 1501*da0073e9SAndroid Build Coastguard Worker lib.define("sin.blah(Tensor x) -> Tensor") 1502*da0073e9SAndroid Build Coastguard Worker 1503*da0073e9SAndroid Build Coastguard Worker torch.library.impl_abstract( 1504*da0073e9SAndroid Build Coastguard Worker f"{self.test_ns}::sin.blah", torch.empty_like, lib=lib 1505*da0073e9SAndroid Build Coastguard Worker ) 1506*da0073e9SAndroid Build Coastguard Worker 1507*da0073e9SAndroid Build Coastguard Worker op = self.ns().sin.blah 1508*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, device="meta") 1509*da0073e9SAndroid Build Coastguard Worker op(x) 1510*da0073e9SAndroid Build Coastguard Worker 1511*da0073e9SAndroid Build Coastguard Worker def 
test_impl_meta(self): 1512*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 1513*da0073e9SAndroid Build Coastguard Worker def foo(x: torch.Tensor, dim: int) -> torch.Tensor: 1514*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1515*da0073e9SAndroid Build Coastguard Worker 1516*da0073e9SAndroid Build Coastguard Worker @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib()) 1517*da0073e9SAndroid Build Coastguard Worker def foo_meta(x, dim): 1518*da0073e9SAndroid Build Coastguard Worker output_shape = list(x.shape) 1519*da0073e9SAndroid Build Coastguard Worker del output_shape[dim] 1520*da0073e9SAndroid Build Coastguard Worker return x.new_empty(output_shape) 1521*da0073e9SAndroid Build Coastguard Worker 1522*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3, device="meta") 1523*da0073e9SAndroid Build Coastguard Worker op = self.get_op(f"{self.test_ns}::foo") 1524*da0073e9SAndroid Build Coastguard Worker result = op(x, 1) 1525*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.shape, foo_meta(x, 1).shape) 1526*da0073e9SAndroid Build Coastguard Worker 1527*da0073e9SAndroid Build Coastguard Worker def test_duplicate_impl(self): 1528*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 1529*da0073e9SAndroid Build Coastguard Worker def foo(x: torch.Tensor, dim: int) -> torch.Tensor: 1530*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1531*da0073e9SAndroid Build Coastguard Worker 1532*da0073e9SAndroid Build Coastguard Worker @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib()) 1533*da0073e9SAndroid Build Coastguard Worker def foo_meta(x, dim): 1534*da0073e9SAndroid Build Coastguard Worker output_shape = list(x.shape) 1535*da0073e9SAndroid Build Coastguard Worker del output_shape[dim] 1536*da0073e9SAndroid Build Coastguard Worker return x.new_empty(output_shape) 
1537*da0073e9SAndroid Build Coastguard Worker 1538*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"test_custom_ops.py:\d+"): 1539*da0073e9SAndroid Build Coastguard Worker 1540*da0073e9SAndroid Build Coastguard Worker @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib()) 1541*da0073e9SAndroid Build Coastguard Worker def foo_meta2(x, dim): 1542*da0073e9SAndroid Build Coastguard Worker output_shape = list(x.shape) 1543*da0073e9SAndroid Build Coastguard Worker del output_shape[dim] 1544*da0073e9SAndroid Build Coastguard Worker return x.new_empty(output_shape) 1545*da0073e9SAndroid Build Coastguard Worker 1546*da0073e9SAndroid Build Coastguard Worker def test_new_data_dependent_symint(self): 1547*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 1548*da0073e9SAndroid Build Coastguard Worker def foo(x: torch.Tensor) -> torch.Tensor: 1549*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1550*da0073e9SAndroid Build Coastguard Worker 1551*da0073e9SAndroid Build Coastguard Worker @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib()) 1552*da0073e9SAndroid Build Coastguard Worker def foo_meta(x): 1553*da0073e9SAndroid Build Coastguard Worker ctx = torch.library.get_ctx() 1554*da0073e9SAndroid Build Coastguard Worker r = ctx.new_dynamic_size(min=1) 1555*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "greater than or equal to 0"): 1556*da0073e9SAndroid Build Coastguard Worker ctx.new_dynamic_size(min=-1) 1557*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "SymInt"): 1558*da0073e9SAndroid Build Coastguard Worker ctx.new_dynamic_size(max=x.numel()) 1559*da0073e9SAndroid Build Coastguard Worker # NB: You must return dynamic sizes! 
1560*da0073e9SAndroid Build Coastguard Worker return x.new_empty(r) 1561*da0073e9SAndroid Build Coastguard Worker 1562*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3, device="cpu") 1563*da0073e9SAndroid Build Coastguard Worker op = self.get_op(f"{self.test_ns}::foo") 1564*da0073e9SAndroid Build Coastguard Worker make_fx(op, tracing_mode="symbolic")(x) 1565*da0073e9SAndroid Build Coastguard Worker 1566*da0073e9SAndroid Build Coastguard Worker def test_meta_for_data_dependent_shape_operation(self): 1567*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10, device="meta") 1568*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "data-dependent output shape"): 1569*da0073e9SAndroid Build Coastguard Worker numpy_nonzero(x) 1570*da0073e9SAndroid Build Coastguard Worker 1571*da0073e9SAndroid Build Coastguard Worker def test_basic_make_fx(self): 1572*da0073e9SAndroid Build Coastguard Worker # More serious tests are in our CustomOp opinfo db, 1573*da0073e9SAndroid Build Coastguard Worker # this one is just a sanity check. 
1574*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 1575*da0073e9SAndroid Build Coastguard Worker def foo(x: torch.Tensor) -> torch.Tensor: 1576*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1577*da0073e9SAndroid Build Coastguard Worker 1578*da0073e9SAndroid Build Coastguard Worker @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib()) 1579*da0073e9SAndroid Build Coastguard Worker def foo_meta(x): 1580*da0073e9SAndroid Build Coastguard Worker return x.sum() 1581*da0073e9SAndroid Build Coastguard Worker 1582*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 1583*da0073e9SAndroid Build Coastguard Worker op = self.get_op(f"{self.test_ns}::foo") 1584*da0073e9SAndroid Build Coastguard Worker gm = make_fx(op, tracing_mode="symbolic")(x) 1585*da0073e9SAndroid Build Coastguard Worker self.assertTrue(f"{TestCustomOp.test_ns}.foo" in gm.code) 1586*da0073e9SAndroid Build Coastguard Worker 1587*da0073e9SAndroid Build Coastguard Worker def test_not_implemented_error(self): 1588*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 1589*da0073e9SAndroid Build Coastguard Worker def foo(x: torch.Tensor) -> torch.Tensor: 1590*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1591*da0073e9SAndroid Build Coastguard Worker 1592*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 1593*da0073e9SAndroid Build Coastguard Worker op = self.get_op(f"{self.test_ns}::foo") 1594*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(NotImplementedError, "cpu impl registered"): 1595*da0073e9SAndroid Build Coastguard Worker op(x) 1596*da0073e9SAndroid Build Coastguard Worker 1597*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, device="meta") 1598*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(NotImplementedError, "no fake impl or Meta kernel"): 1599*da0073e9SAndroid Build Coastguard 
Worker op(x) 1600*da0073e9SAndroid Build Coastguard Worker 1601*da0073e9SAndroid Build Coastguard Worker @custom_ops.custom_op(f"{TestCustomOp.test_ns}::bar") 1602*da0073e9SAndroid Build Coastguard Worker def bar(sizes: Sequence[int]) -> torch.Tensor: 1603*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1604*da0073e9SAndroid Build Coastguard Worker 1605*da0073e9SAndroid Build Coastguard Worker op = self.get_op(f"{self.test_ns}::bar") 1606*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(NotImplementedError, "no Tensor inputs"): 1607*da0073e9SAndroid Build Coastguard Worker op((1, 2, 3)) 1608*da0073e9SAndroid Build Coastguard Worker 1609*da0073e9SAndroid Build Coastguard Worker def test_data_dependent_basic(self): 1610*da0073e9SAndroid Build Coastguard Worker x = torch.randn(5, 5) 1611*da0073e9SAndroid Build Coastguard Worker gm = make_fx(numpy_nonzero, tracing_mode="symbolic")(x) 1612*da0073e9SAndroid Build Coastguard Worker self.assertTrue("nonzero" in gm.code) 1613*da0073e9SAndroid Build Coastguard Worker 1614*da0073e9SAndroid Build Coastguard Worker def test_data_dependent_fake_tracing(self): 1615*da0073e9SAndroid Build Coastguard Worker x = torch.randn(5, 5) 1616*da0073e9SAndroid Build Coastguard Worker # We've updated to attempt to use unbacked symints even for fake 1617*da0073e9SAndroid Build Coastguard Worker # tracing 1618*da0073e9SAndroid Build Coastguard Worker make_fx(numpy_nonzero, tracing_mode="fake")(x) 1619*da0073e9SAndroid Build Coastguard Worker 1620*da0073e9SAndroid Build Coastguard Worker def test_symints(self): 1621*da0073e9SAndroid Build Coastguard Worker def f(x): 1622*da0073e9SAndroid Build Coastguard Worker return torch.ops._torch_testing.numpy_view_copy(x, x.shape) 1623*da0073e9SAndroid Build Coastguard Worker 1624*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3, 4) 1625*da0073e9SAndroid Build Coastguard Worker gm = make_fx(f, tracing_mode="symbolic")(x) 1626*da0073e9SAndroid Build 
Coastguard Worker result = gm(x) 1627*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, f(x)) 1628*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1629*da0073e9SAndroid Build Coastguard Worker gm.code.strip(), 1630*da0073e9SAndroid Build Coastguard Worker """\ 1631*da0073e9SAndroid Build Coastguard Workerdef forward(self, x_1): 1632*da0073e9SAndroid Build Coastguard Worker sym_size_int = torch.ops.aten.sym_size.int(x_1, 0) 1633*da0073e9SAndroid Build Coastguard Worker sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1) 1634*da0073e9SAndroid Build Coastguard Worker sym_size_int_2 = torch.ops.aten.sym_size.int(x_1, 2) 1635*da0073e9SAndroid Build Coastguard Worker numpy_view_copy = torch.ops._torch_testing.numpy_view_copy.default(x_1, [sym_size_int, sym_size_int_1, sym_size_int_2]); x_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = None 1636*da0073e9SAndroid Build Coastguard Worker return numpy_view_copy""", # noqa: B950 1637*da0073e9SAndroid Build Coastguard Worker ) 1638*da0073e9SAndroid Build Coastguard Worker 1639*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work on windows") 1640*da0073e9SAndroid Build Coastguard Worker def test_data_dependent_compile(self): 1641*da0073e9SAndroid Build Coastguard Worker import torch._dynamo.testing 1642*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.utils import counters 1643*da0073e9SAndroid Build Coastguard Worker 1644*da0073e9SAndroid Build Coastguard Worker counters.clear() 1645*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 1646*da0073e9SAndroid Build Coastguard Worker 1647*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend=cnt) 1648*da0073e9SAndroid Build Coastguard Worker def f(x): 1649*da0073e9SAndroid Build Coastguard Worker return numpy_nonzero(x.clone()).clone() 1650*da0073e9SAndroid Build Coastguard Worker 1651*da0073e9SAndroid Build Coastguard Worker 
f(torch.randn(10)) 1652*da0073e9SAndroid Build Coastguard Worker 1653*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(counters["graph_break"]), 1) 1654*da0073e9SAndroid Build Coastguard Worker self.assertEqual(next(iter(counters["graph_break"].values())), 1) 1655*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1656*da0073e9SAndroid Build Coastguard Worker next(iter(counters["graph_break"].keys())).replace(";", "\n"), 1657*da0073e9SAndroid Build Coastguard Worker """\ 1658*da0073e9SAndroid Build Coastguard Workerdynamic shape operator: _torch_testing.numpy_nonzero.default 1659*da0073e9SAndroid Build Coastguard Worker to enable, set torch._dynamo.config.capture_dynamic_output_shape_ops = True""", 1660*da0073e9SAndroid Build Coastguard Worker ) 1661*da0073e9SAndroid Build Coastguard Worker 1662*da0073e9SAndroid Build Coastguard Worker # pre-existing problem: torch.compile(dynamic=True) will, by default, 1663*da0073e9SAndroid Build Coastguard Worker # graph break on data-dependent operations. Eventually we'll make it so 1664*da0073e9SAndroid Build Coastguard Worker # that it never graph breaks on data-dependent operations. 
1665*da0073e9SAndroid Build Coastguard Worker @unittest.expectedFailure 1666*da0073e9SAndroid Build Coastguard Worker def test_data_dependent_nms_dynamic_compile(self): 1667*da0073e9SAndroid Build Coastguard Worker import torch._dynamo.testing 1668*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.utils import counters 1669*da0073e9SAndroid Build Coastguard Worker 1670*da0073e9SAndroid Build Coastguard Worker counters.clear() 1671*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 1672*da0073e9SAndroid Build Coastguard Worker 1673*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend=cnt, dynamic=True) 1674*da0073e9SAndroid Build Coastguard Worker def f(x, s, i): 1675*da0073e9SAndroid Build Coastguard Worker return torch.ops._torch_testing.numpy_nms(x.clone(), s, i).clone() 1676*da0073e9SAndroid Build Coastguard Worker 1677*da0073e9SAndroid Build Coastguard Worker f(torch.randn(20, 4), torch.randn(20), 0.1) 1678*da0073e9SAndroid Build Coastguard Worker 1679*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(counters["graph_break"]), 0) 1680*da0073e9SAndroid Build Coastguard Worker 1681*da0073e9SAndroid Build Coastguard Worker def test_impl_on_existing_op(self): 1682*da0073e9SAndroid Build Coastguard Worker lib = self.lib() 1683*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor x) -> Tensor") 1684*da0073e9SAndroid Build Coastguard Worker qualname = f"{self.test_ns}::foo" 1685*da0073e9SAndroid Build Coastguard Worker 1686*da0073e9SAndroid Build Coastguard Worker @torch._custom_ops.impl(qualname) 1687*da0073e9SAndroid Build Coastguard Worker def foo_impl(x): 1688*da0073e9SAndroid Build Coastguard Worker return x.sin() 1689*da0073e9SAndroid Build Coastguard Worker 1690*da0073e9SAndroid Build Coastguard Worker op = self.get_op(qualname) 1691*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 1692*da0073e9SAndroid Build Coastguard Worker result = op(x) 1693*da0073e9SAndroid Build 
Coastguard Worker self.assertEqual(result, x.sin()) 1694*da0073e9SAndroid Build Coastguard Worker 1695*da0073e9SAndroid Build Coastguard Worker @parametrize( 1696*da0073e9SAndroid Build Coastguard Worker "key", ["CPU", "CUDA", "CompositeImplicitAutograd", "CompositeExplicitAutograd"] 1697*da0073e9SAndroid Build Coastguard Worker ) 1698*da0073e9SAndroid Build Coastguard Worker def test_impl_on_existing_op_with_cpu_registration(self, key): 1699*da0073e9SAndroid Build Coastguard Worker lib = self.lib() 1700*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor x) -> Tensor") 1701*da0073e9SAndroid Build Coastguard Worker qualname = f"{self.test_ns}::foo" 1702*da0073e9SAndroid Build Coastguard Worker 1703*da0073e9SAndroid Build Coastguard Worker def foo_impl(x): 1704*da0073e9SAndroid Build Coastguard Worker return x.sin() 1705*da0073e9SAndroid Build Coastguard Worker 1706*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", foo_impl, key) 1707*da0073e9SAndroid Build Coastguard Worker op = self.get_op(qualname) 1708*da0073e9SAndroid Build Coastguard Worker 1709*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "already has an implementation"): 1710*da0073e9SAndroid Build Coastguard Worker custom_ops.impl(qualname, func=foo_impl) 1711*da0073e9SAndroid Build Coastguard Worker 1712*da0073e9SAndroid Build Coastguard Worker def test_abstract_impl_on_existing_op(self): 1713*da0073e9SAndroid Build Coastguard Worker lib = self.lib() 1714*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor x) -> Tensor") 1715*da0073e9SAndroid Build Coastguard Worker qualname = f"{self.test_ns}::foo" 1716*da0073e9SAndroid Build Coastguard Worker 1717*da0073e9SAndroid Build Coastguard Worker @torch.library.impl_abstract(qualname, lib=self.lib()) 1718*da0073e9SAndroid Build Coastguard Worker def foo_impl(x): 1719*da0073e9SAndroid Build Coastguard Worker return x.sin() 1720*da0073e9SAndroid Build Coastguard Worker 1721*da0073e9SAndroid Build 
Coastguard Worker op = self.get_op(qualname) 1722*da0073e9SAndroid Build Coastguard Worker with torch._subclasses.FakeTensorMode(): 1723*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 1724*da0073e9SAndroid Build Coastguard Worker result = op(x) 1725*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.shape, x.shape) 1726*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.stride(), x.stride()) 1727*da0073e9SAndroid Build Coastguard Worker 1728*da0073e9SAndroid Build Coastguard Worker def test_abstract_impl_on_existing_op_with_meta(self): 1729*da0073e9SAndroid Build Coastguard Worker lib = self.lib() 1730*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor x) -> Tensor") 1731*da0073e9SAndroid Build Coastguard Worker qualname = f"{self.test_ns}::foo" 1732*da0073e9SAndroid Build Coastguard Worker 1733*da0073e9SAndroid Build Coastguard Worker def foo_impl(x): 1734*da0073e9SAndroid Build Coastguard Worker return x.sin() 1735*da0073e9SAndroid Build Coastguard Worker 1736*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", foo_impl, "Meta") 1737*da0073e9SAndroid Build Coastguard Worker op = self.get_op(qualname) 1738*da0073e9SAndroid Build Coastguard Worker 1739*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"already has .*Meta implementation"): 1740*da0073e9SAndroid Build Coastguard Worker torch.library.impl_abstract(qualname, func=foo_impl, lib=self.lib()) 1741*da0073e9SAndroid Build Coastguard Worker 1742*da0073e9SAndroid Build Coastguard Worker def test_abstract_impl_on_existing_op_with_CompositeImplicitAutograd(self): 1743*da0073e9SAndroid Build Coastguard Worker lib = self.lib() 1744*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor x) -> Tensor") 1745*da0073e9SAndroid Build Coastguard Worker qualname = f"{self.test_ns}::foo" 1746*da0073e9SAndroid Build Coastguard Worker 1747*da0073e9SAndroid Build Coastguard Worker def foo_impl(x): 1748*da0073e9SAndroid Build 
Coastguard Worker return x.sin() 1749*da0073e9SAndroid Build Coastguard Worker 1750*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", foo_impl, "CompositeImplicitAutograd") 1751*da0073e9SAndroid Build Coastguard Worker op = self.get_op(qualname) 1752*da0073e9SAndroid Build Coastguard Worker 1753*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "CompositeImplicitAutograd"): 1754*da0073e9SAndroid Build Coastguard Worker torch.library.impl_abstract(qualname, func=foo_impl, lib=self.lib()) 1755*da0073e9SAndroid Build Coastguard Worker 1756*da0073e9SAndroid Build Coastguard Worker def test_abstract_impl_on_existing_op_with_CompositeExplicitAutograd(self): 1757*da0073e9SAndroid Build Coastguard Worker lib = self.lib() 1758*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor x) -> Tensor") 1759*da0073e9SAndroid Build Coastguard Worker qualname = f"{self.test_ns}::foo" 1760*da0073e9SAndroid Build Coastguard Worker 1761*da0073e9SAndroid Build Coastguard Worker def foo_impl(x): 1762*da0073e9SAndroid Build Coastguard Worker return x.sin() 1763*da0073e9SAndroid Build Coastguard Worker 1764*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", foo_impl, "CompositeExplicitAutograd") 1765*da0073e9SAndroid Build Coastguard Worker op = self.get_op(qualname) 1766*da0073e9SAndroid Build Coastguard Worker 1767*da0073e9SAndroid Build Coastguard Worker torch.library.impl_abstract(qualname, func=lambda x: x.sum(), lib=self.lib()) 1768*da0073e9SAndroid Build Coastguard Worker with torch._subclasses.FakeTensorMode(): 1769*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10) 1770*da0073e9SAndroid Build Coastguard Worker result = op(x) 1771*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.shape, ()) 1772*da0073e9SAndroid Build Coastguard Worker 1773*da0073e9SAndroid Build Coastguard Worker def _test_backward_impl_raises(self, qualname, err_regex): 1774*da0073e9SAndroid Build Coastguard Worker with 
self.assertRaisesRegex(RuntimeError, err_regex): 1775*da0073e9SAndroid Build Coastguard Worker 1776*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl_save_for_backward(qualname) 1777*da0073e9SAndroid Build Coastguard Worker def foo2(x): 1778*da0073e9SAndroid Build Coastguard Worker return 1779*da0073e9SAndroid Build Coastguard Worker 1780*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, err_regex): 1781*da0073e9SAndroid Build Coastguard Worker 1782*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl_backward(qualname) 1783*da0073e9SAndroid Build Coastguard Worker def foo3(x): 1784*da0073e9SAndroid Build Coastguard Worker return 1785*da0073e9SAndroid Build Coastguard Worker 1786*da0073e9SAndroid Build Coastguard Worker def test_backward_impl_on_existing_op_incorrect_schema_views(self): 1787*da0073e9SAndroid Build Coastguard Worker lib = self.lib() 1788*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor(a) x) -> Tensor(a)") 1789*da0073e9SAndroid Build Coastguard Worker qualname = f"{self.test_ns}::foo" 1790*da0073e9SAndroid Build Coastguard Worker self._test_backward_impl_raises(qualname, "operator that returns views") 1791*da0073e9SAndroid Build Coastguard Worker 1792*da0073e9SAndroid Build Coastguard Worker def test_backward_impl_on_existing_op_incorrect_schema_mutable(self): 1793*da0073e9SAndroid Build Coastguard Worker lib = self.lib() 1794*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor(a!) 
x) -> Tensor") 1795*da0073e9SAndroid Build Coastguard Worker qualname = f"{self.test_ns}::foo" 1796*da0073e9SAndroid Build Coastguard Worker self._test_backward_impl_raises(qualname, "non-functional") 1797*da0073e9SAndroid Build Coastguard Worker 1798*da0073e9SAndroid Build Coastguard Worker def test_backward_impl_on_existing_op_incorrect_schema_no_output(self): 1799*da0073e9SAndroid Build Coastguard Worker lib = self.lib() 1800*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor x) -> ()") 1801*da0073e9SAndroid Build Coastguard Worker qualname = f"{self.test_ns}::foo" 1802*da0073e9SAndroid Build Coastguard Worker self._test_backward_impl_raises(qualname, "no returns") 1803*da0073e9SAndroid Build Coastguard Worker 1804*da0073e9SAndroid Build Coastguard Worker def test_backward_impl_on_existing_op_CompositeImplicitAutograd(self): 1805*da0073e9SAndroid Build Coastguard Worker lib = self.lib() 1806*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor x) -> Tensor") 1807*da0073e9SAndroid Build Coastguard Worker qualname = f"{self.test_ns}::foo" 1808*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", lambda x: x.sin().cos(), "CompositeImplicitAutograd") 1809*da0073e9SAndroid Build Coastguard Worker self._test_backward_impl_raises(qualname, "CompositeImplicitAutograd") 1810*da0073e9SAndroid Build Coastguard Worker 1811*da0073e9SAndroid Build Coastguard Worker @parametrize("key", ["Autograd", "AutogradCPU", "AutogradCUDA"]) 1812*da0073e9SAndroid Build Coastguard Worker def test_backward_impl_on_existing_op_with_key(self, key): 1813*da0073e9SAndroid Build Coastguard Worker lib = self.lib() 1814*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor x) -> Tensor") 1815*da0073e9SAndroid Build Coastguard Worker qualname = f"{self.test_ns}::foo" 1816*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", lambda x: x.sin().cos(), key) 1817*da0073e9SAndroid Build Coastguard Worker self._test_backward_impl_raises(qualname, key) 
1818*da0073e9SAndroid Build Coastguard Worker 1819*da0073e9SAndroid Build Coastguard Worker def test_is_functional_schema(self): 1820*da0073e9SAndroid Build Coastguard Worker tests = { 1821*da0073e9SAndroid Build Coastguard Worker "foo(Tensor x) -> Tensor": True, 1822*da0073e9SAndroid Build Coastguard Worker "foo(Tensor(a) x) -> Tensor": True, 1823*da0073e9SAndroid Build Coastguard Worker "foo(Tensor(a!) x) -> Tensor": False, 1824*da0073e9SAndroid Build Coastguard Worker "foo(Tensor(a) x) -> Tensor(a)": False, 1825*da0073e9SAndroid Build Coastguard Worker "foo(Tensor x) -> ()": False, 1826*da0073e9SAndroid Build Coastguard Worker } 1827*da0073e9SAndroid Build Coastguard Worker for schema_str, expected in tests.items(): 1828*da0073e9SAndroid Build Coastguard Worker res = torch._library.utils.is_functional_schema(schema_str) 1829*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, expected) 1830*da0073e9SAndroid Build Coastguard Worker 1831*da0073e9SAndroid Build Coastguard Worker from torchgen.model import FunctionSchema 1832*da0073e9SAndroid Build Coastguard Worker 1833*da0073e9SAndroid Build Coastguard Worker schema = FunctionSchema.parse(schema_str) 1834*da0073e9SAndroid Build Coastguard Worker res = torch._library.utils.is_functional_schema(schema) 1835*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, expected) 1836*da0073e9SAndroid Build Coastguard Worker 1837*da0073e9SAndroid Build Coastguard Worker schema = torch._C.parse_schema(schema_str) 1838*da0073e9SAndroid Build Coastguard Worker res = torch._library.utils.is_functional_schema(schema) 1839*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, expected) 1840*da0073e9SAndroid Build Coastguard Worker 1841*da0073e9SAndroid Build Coastguard Worker def test_incorrect_schema_types(self): 1842*da0073e9SAndroid Build Coastguard Worker with torch.library._scoped_library("mylib", "FRAGMENT") as lib: 1843*da0073e9SAndroid Build Coastguard Worker with 
self.assertRaisesRegex(RuntimeError, "unknown type specifier"): 1844*da0073e9SAndroid Build Coastguard Worker lib.define("foo12(Tensor a) -> asdfasdf") 1845*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "unknown type specifier"): 1846*da0073e9SAndroid Build Coastguard Worker lib.define("foo12(asdf a) -> Tensor") 1847*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Use `SymInt` or `int`"): 1848*da0073e9SAndroid Build Coastguard Worker lib.define("foo12(int64_t a) -> Tensor") 1849*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Use `float`"): 1850*da0073e9SAndroid Build Coastguard Worker lib.define("foo12(double a) -> Tensor") 1851*da0073e9SAndroid Build Coastguard Worker 1852*da0073e9SAndroid Build Coastguard Worker def test_is_tensorlist_like_type(self): 1853*da0073e9SAndroid Build Coastguard Worker tensorlists = [ 1854*da0073e9SAndroid Build Coastguard Worker # Tensor[] 1855*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.where.default._schema.returns[0].type, 1856*da0073e9SAndroid Build Coastguard Worker # Tensor?[] 1857*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.index.Tensor._schema.arguments[1].type, 1858*da0073e9SAndroid Build Coastguard Worker # Tensor[]? 1859*da0073e9SAndroid Build Coastguard Worker torch._C.parse_schema("foo(Tensor[]? x) -> ()").arguments[0].type, 1860*da0073e9SAndroid Build Coastguard Worker # Tensor?[]? 1861*da0073e9SAndroid Build Coastguard Worker torch._C.parse_schema("foo(Tensor?[]? 
x) -> ()").arguments[0].type, 1862*da0073e9SAndroid Build Coastguard Worker ] 1863*da0073e9SAndroid Build Coastguard Worker non_tensorlists = [ 1864*da0073e9SAndroid Build Coastguard Worker # Tensor 1865*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.sin.default._schema.arguments[0].type, 1866*da0073e9SAndroid Build Coastguard Worker # IntList 1867*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.sum.dim_IntList._schema.arguments[1].type, 1868*da0073e9SAndroid Build Coastguard Worker ] 1869*da0073e9SAndroid Build Coastguard Worker for a in tensorlists: 1870*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch._library.utils.is_tensorlist_like_type(a)) 1871*da0073e9SAndroid Build Coastguard Worker for a in non_tensorlists: 1872*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch._library.utils.is_tensorlist_like_type(a)) 1873*da0073e9SAndroid Build Coastguard Worker 1874*da0073e9SAndroid Build Coastguard Worker def test_backward_impl_on_existing_op(self): 1875*da0073e9SAndroid Build Coastguard Worker lib = self.lib() 1876*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor x) -> Tensor") 1877*da0073e9SAndroid Build Coastguard Worker qualname = f"{self.test_ns}::foo" 1878*da0073e9SAndroid Build Coastguard Worker 1879*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl(qualname) 1880*da0073e9SAndroid Build Coastguard Worker def foo_impl(x): 1881*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 1882*da0073e9SAndroid Build Coastguard Worker return x.sin() 1883*da0073e9SAndroid Build Coastguard Worker 1884*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl_save_for_backward(qualname) 1885*da0073e9SAndroid Build Coastguard Worker def foo_save_for_backward(inputs, output): 1886*da0073e9SAndroid Build Coastguard Worker return inputs.x 1887*da0073e9SAndroid Build Coastguard Worker 1888*da0073e9SAndroid Build Coastguard Worker @custom_ops.impl_backward(qualname) 1889*da0073e9SAndroid Build Coastguard 
Worker def foo_backward(ctx, saved, grad_out): 1890*da0073e9SAndroid Build Coastguard Worker return {"x": grad_out * saved.cos()} 1891*da0073e9SAndroid Build Coastguard Worker 1892*da0073e9SAndroid Build Coastguard Worker op = self.get_op(qualname) 1893*da0073e9SAndroid Build Coastguard Worker x = torch.randn([], requires_grad=True) 1894*da0073e9SAndroid Build Coastguard Worker y = op(x) 1895*da0073e9SAndroid Build Coastguard Worker (gx,) = torch.autograd.grad(y, x) 1896*da0073e9SAndroid Build Coastguard Worker self.assertEqual(gx, x.cos()) 1897*da0073e9SAndroid Build Coastguard Worker 1898*da0073e9SAndroid Build Coastguard Worker @parametrize( 1899*da0073e9SAndroid Build Coastguard Worker "tags", 1900*da0073e9SAndroid Build Coastguard Worker [ 1901*da0073e9SAndroid Build Coastguard Worker subtest(torch.Tag.pointwise, "single"), 1902*da0073e9SAndroid Build Coastguard Worker subtest((torch.Tag.pointwise,), "tuple"), 1903*da0073e9SAndroid Build Coastguard Worker subtest([torch.Tag.pointwise], "list"), 1904*da0073e9SAndroid Build Coastguard Worker ], 1905*da0073e9SAndroid Build Coastguard Worker ) 1906*da0073e9SAndroid Build Coastguard Worker def test_define_with_tags(self, tags): 1907*da0073e9SAndroid Build Coastguard Worker lib = self.lib() 1908*da0073e9SAndroid Build Coastguard Worker tags = (torch.Tag.pointwise,) 1909*da0073e9SAndroid Build Coastguard Worker torch.library.define( 1910*da0073e9SAndroid Build Coastguard Worker f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib, tags=tags 1911*da0073e9SAndroid Build Coastguard Worker ) 1912*da0073e9SAndroid Build Coastguard Worker actual = self.ns().foo.default.tags 1913*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(actual, list)) 1914*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, list(tags)) 1915*da0073e9SAndroid Build Coastguard Worker 1916*da0073e9SAndroid Build Coastguard Worker def test_builtin_aten_ops_are_pt2_compliant(self): 1917*da0073e9SAndroid Build Coastguard 
Worker for op in [torch.ops.aten.sin.default, torch.ops.aten.sum.dim_IntList]: 1918*da0073e9SAndroid Build Coastguard Worker self.assertIn(torch.Tag.pt2_compliant_tag, op.tags) 1919*da0073e9SAndroid Build Coastguard Worker 1920*da0073e9SAndroid Build Coastguard Worker def test_builtin_torchscript_ops(self): 1921*da0073e9SAndroid Build Coastguard Worker for op in [torch.ops.aten.sub.complex, torch.ops.aten.mul.complex]: 1922*da0073e9SAndroid Build Coastguard Worker self.assertIn(torch.Tag.pt2_compliant_tag, op.tags) 1923*da0073e9SAndroid Build Coastguard Worker 1924*da0073e9SAndroid Build Coastguard Worker def test_autogen_aten_ops_are_pt2_compliant(self): 1925*da0073e9SAndroid Build Coastguard Worker for op in [torch.ops.aten.fill.Tensor_out]: 1926*da0073e9SAndroid Build Coastguard Worker self.assertIn(torch.Tag.generated, op.tags) 1927*da0073e9SAndroid Build Coastguard Worker self.assertIn(torch.Tag.pt2_compliant_tag, op.tags) 1928*da0073e9SAndroid Build Coastguard Worker 1929*da0073e9SAndroid Build Coastguard Worker def test_resolve_packet(self): 1930*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 1931*da0073e9SAndroid Build Coastguard Worker result = torch._C._jit_resolve_packet("aten::sum", x) 1932*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, "default") 1933*da0073e9SAndroid Build Coastguard Worker 1934*da0073e9SAndroid Build Coastguard Worker result = torch._C._jit_resolve_packet("aten::sum", x, dim=1) 1935*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, "dim_IntList") 1936*da0073e9SAndroid Build Coastguard Worker 1937*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "failed to match any schema"): 1938*da0073e9SAndroid Build Coastguard Worker result = torch._C._jit_resolve_packet("aten::sum", x, x, x) 1939*da0073e9SAndroid Build Coastguard Worker 1940*da0073e9SAndroid Build Coastguard Worker def test_define_bad_schema(self): 1941*da0073e9SAndroid Build Coastguard Worker lib = 
self.lib() 1942*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "expected schema to look like"): 1943*da0073e9SAndroid Build Coastguard Worker torch.library.define(f"{self.test_ns}::foo", "foo(Tensor x) -> Tensor") 1944*da0073e9SAndroid Build Coastguard Worker 1945*da0073e9SAndroid Build Coastguard Worker def test_define_and_impl(self): 1946*da0073e9SAndroid Build Coastguard Worker lib = self.lib() 1947*da0073e9SAndroid Build Coastguard Worker torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib) 1948*da0073e9SAndroid Build Coastguard Worker 1949*da0073e9SAndroid Build Coastguard Worker @torch.library.impl(f"{self.test_ns}::foo", "CPU", lib=lib) 1950*da0073e9SAndroid Build Coastguard Worker def f(x): 1951*da0073e9SAndroid Build Coastguard Worker return torch.from_numpy(np.sin(x.numpy())) 1952*da0073e9SAndroid Build Coastguard Worker 1953*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 1954*da0073e9SAndroid Build Coastguard Worker y = self.ns().foo(x) 1955*da0073e9SAndroid Build Coastguard Worker assert torch.allclose(y, x.sin()) 1956*da0073e9SAndroid Build Coastguard Worker 1957*da0073e9SAndroid Build Coastguard Worker def test_define_validation(self): 1958*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "namespace"): 1959*da0073e9SAndroid Build Coastguard Worker torch.library.define("foo", "(Tensor x) -> Tensor") 1960*da0073e9SAndroid Build Coastguard Worker 1961*da0073e9SAndroid Build Coastguard Worker def test_legacy_define(self): 1962*da0073e9SAndroid Build Coastguard Worker lib = self.lib() 1963*da0073e9SAndroid Build Coastguard Worker 1964*da0073e9SAndroid Build Coastguard Worker @torch.library.define(lib, "foo(Tensor x) -> Tensor") 1965*da0073e9SAndroid Build Coastguard Worker def f(x): 1966*da0073e9SAndroid Build Coastguard Worker return torch.from_numpy(np.sin(x.numpy())) 1967*da0073e9SAndroid Build Coastguard Worker 1968*da0073e9SAndroid Build Coastguard 
Worker x = torch.randn(3) 1969*da0073e9SAndroid Build Coastguard Worker y = self.ns().foo(x) 1970*da0073e9SAndroid Build Coastguard Worker assert torch.allclose(y, x.sin()) 1971*da0073e9SAndroid Build Coastguard Worker 1972*da0073e9SAndroid Build Coastguard Worker def test_impl_function(self): 1973*da0073e9SAndroid Build Coastguard Worker lib = self.lib() 1974*da0073e9SAndroid Build Coastguard Worker torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib) 1975*da0073e9SAndroid Build Coastguard Worker 1976*da0073e9SAndroid Build Coastguard Worker def f(x): 1977*da0073e9SAndroid Build Coastguard Worker return torch.from_numpy(np.sin(x.numpy())) 1978*da0073e9SAndroid Build Coastguard Worker 1979*da0073e9SAndroid Build Coastguard Worker torch.library.impl(f"{self.test_ns}::foo", "CPU", f, lib=lib) 1980*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 1981*da0073e9SAndroid Build Coastguard Worker y = self.ns().foo(x) 1982*da0073e9SAndroid Build Coastguard Worker assert torch.allclose(y, x.sin()) 1983*da0073e9SAndroid Build Coastguard Worker 1984*da0073e9SAndroid Build Coastguard Worker def test_legacy_impl(self): 1985*da0073e9SAndroid Build Coastguard Worker lib = self.lib() 1986*da0073e9SAndroid Build Coastguard Worker torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib) 1987*da0073e9SAndroid Build Coastguard Worker 1988*da0073e9SAndroid Build Coastguard Worker @torch.library.impl(lib, "foo", "CPU") 1989*da0073e9SAndroid Build Coastguard Worker def f(x): 1990*da0073e9SAndroid Build Coastguard Worker return torch.from_numpy(np.sin(x.numpy())) 1991*da0073e9SAndroid Build Coastguard Worker 1992*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 1993*da0073e9SAndroid Build Coastguard Worker y = self.ns().foo(x) 1994*da0073e9SAndroid Build Coastguard Worker assert torch.allclose(y, x.sin()) 1995*da0073e9SAndroid Build Coastguard Worker 1996*da0073e9SAndroid Build Coastguard Worker def 
test_defined_in_python(self): 1997*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.ops.aten.sin.default._defined_in_python) 1998*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.ops.aten.sum.dim_IntList._defined_in_python) 1999*da0073e9SAndroid Build Coastguard Worker 2000*da0073e9SAndroid Build Coastguard Worker lib = self.lib() 2001*da0073e9SAndroid Build Coastguard Worker torch.library.define("{self._test_ns}::foo", "(Tensor x) -> Tensor", lib=lib) 2002*da0073e9SAndroid Build Coastguard Worker ns = self.ns() 2003*da0073e9SAndroid Build Coastguard Worker self.assertTrue(ns.foo.default._defined_in_python) 2004*da0073e9SAndroid Build Coastguard Worker 2005*da0073e9SAndroid Build Coastguard Worker torch.library.define( 2006*da0073e9SAndroid Build Coastguard Worker "{self._test_ns}::bar.overload", "(Tensor x) -> Tensor", lib=lib 2007*da0073e9SAndroid Build Coastguard Worker ) 2008*da0073e9SAndroid Build Coastguard Worker self.assertTrue(ns.bar.overload._defined_in_python) 2009*da0073e9SAndroid Build Coastguard Worker 2010*da0073e9SAndroid Build Coastguard Worker def _test_impl_device(self, name, types, device): 2011*da0073e9SAndroid Build Coastguard Worker lib = self.lib() 2012*da0073e9SAndroid Build Coastguard Worker torch.library.define(f"{self.test_ns}::{name}", "(Tensor x) -> Tensor", lib=lib) 2013*da0073e9SAndroid Build Coastguard Worker 2014*da0073e9SAndroid Build Coastguard Worker @torch.library.impl(f"{self.test_ns}::{name}", types) 2015*da0073e9SAndroid Build Coastguard Worker def f(x): 2016*da0073e9SAndroid Build Coastguard Worker x_np = x.cpu().numpy() 2017*da0073e9SAndroid Build Coastguard Worker y = torch.from_numpy(np.sin(x_np)) 2018*da0073e9SAndroid Build Coastguard Worker return y.to(device=x.device) 2019*da0073e9SAndroid Build Coastguard Worker 2020*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, device=device) 2021*da0073e9SAndroid Build Coastguard Worker y = getattr(self.ns(), name)(x) 
2022*da0073e9SAndroid Build Coastguard Worker assert torch.allclose(y, x.sin()) 2023*da0073e9SAndroid Build Coastguard Worker 2024*da0073e9SAndroid Build Coastguard Worker def test_impl_device_cpu(self): 2025*da0073e9SAndroid Build Coastguard Worker self._test_impl_device("foo1", "default", "cpu") 2026*da0073e9SAndroid Build Coastguard Worker self._test_impl_device("foo2", ["cpu"], "cpu") 2027*da0073e9SAndroid Build Coastguard Worker self._test_impl_device("foo3", ["cpu", "cuda"], "cpu") 2028*da0073e9SAndroid Build Coastguard Worker 2029*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA, "requires cuda") 2030*da0073e9SAndroid Build Coastguard Worker def test_impl_device_cuda(self): 2031*da0073e9SAndroid Build Coastguard Worker self._test_impl_device("foo4", "default", "cuda") 2032*da0073e9SAndroid Build Coastguard Worker self._test_impl_device("foo5", ["cuda"], "cuda") 2033*da0073e9SAndroid Build Coastguard Worker self._test_impl_device("foo6", ["cpu", "cuda"], "cuda") 2034*da0073e9SAndroid Build Coastguard Worker 2035*da0073e9SAndroid Build Coastguard Worker def test_impl_device_function(self): 2036*da0073e9SAndroid Build Coastguard Worker lib = self.lib() 2037*da0073e9SAndroid Build Coastguard Worker torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib) 2038*da0073e9SAndroid Build Coastguard Worker 2039*da0073e9SAndroid Build Coastguard Worker def f(x): 2040*da0073e9SAndroid Build Coastguard Worker x_np = x.cpu().numpy() 2041*da0073e9SAndroid Build Coastguard Worker y = torch.from_numpy(np.sin(x_np)) 2042*da0073e9SAndroid Build Coastguard Worker return y.to(device=x.device) 2043*da0073e9SAndroid Build Coastguard Worker 2044*da0073e9SAndroid Build Coastguard Worker torch.library.impl(f"{self.test_ns}::foo", "default", f, lib=lib) 2045*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 2046*da0073e9SAndroid Build Coastguard Worker y = self.ns().foo(x) 2047*da0073e9SAndroid Build Coastguard Worker assert 
torch.allclose(y, x.sin()) 2048*da0073e9SAndroid Build Coastguard Worker 2049*da0073e9SAndroid Build Coastguard Worker def test_impl_device_invalid(self): 2050*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Expected one of cpu, cuda"): 2051*da0073e9SAndroid Build Coastguard Worker torch.library.impl("blah::blah", "somethingsomething") 2052*da0073e9SAndroid Build Coastguard Worker 2053*da0073e9SAndroid Build Coastguard Worker def test_autograd_function_backed_op(self): 2054*da0073e9SAndroid Build Coastguard Worker cpp_source = """ 2055*da0073e9SAndroid Build Coastguard Workerstruct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> { 2056*da0073e9SAndroid Build Coastguard Worker static constexpr bool is_traceable = true; 2057*da0073e9SAndroid Build Coastguard Worker 2058*da0073e9SAndroid Build Coastguard Worker static torch::Tensor forward( 2059*da0073e9SAndroid Build Coastguard Worker torch::autograd::AutogradContext* ctx, 2060*da0073e9SAndroid Build Coastguard Worker const torch::Tensor& x) { 2061*da0073e9SAndroid Build Coastguard Worker return x; 2062*da0073e9SAndroid Build Coastguard Worker } 2063*da0073e9SAndroid Build Coastguard Worker 2064*da0073e9SAndroid Build Coastguard Worker static torch::autograd::variable_list backward( 2065*da0073e9SAndroid Build Coastguard Worker torch::autograd::AutogradContext *ctx, 2066*da0073e9SAndroid Build Coastguard Worker torch::autograd::variable_list grad_output) { 2067*da0073e9SAndroid Build Coastguard Worker return grad_output; 2068*da0073e9SAndroid Build Coastguard Worker } 2069*da0073e9SAndroid Build Coastguard Worker}; 2070*da0073e9SAndroid Build Coastguard Worker 2071*da0073e9SAndroid Build Coastguard Workertorch::Tensor custom_op_backed_by_autograd_fn(const torch::Tensor& x) { 2072*da0073e9SAndroid Build Coastguard Worker return CustomOpAutogradFunction::apply(x); 2073*da0073e9SAndroid Build Coastguard Worker} 2074*da0073e9SAndroid Build 
Coastguard Worker 2075*da0073e9SAndroid Build Coastguard WorkerTORCH_LIBRARY(mylib, m) { 2076*da0073e9SAndroid Build Coastguard Worker m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn); 2077*da0073e9SAndroid Build Coastguard Worker} 2078*da0073e9SAndroid Build Coastguard Worker """ 2079*da0073e9SAndroid Build Coastguard Worker 2080*da0073e9SAndroid Build Coastguard Worker module = torch.utils.cpp_extension.load_inline( 2081*da0073e9SAndroid Build Coastguard Worker name="mylib", 2082*da0073e9SAndroid Build Coastguard Worker cpp_sources=cpp_source, 2083*da0073e9SAndroid Build Coastguard Worker functions="custom_op_backed_by_autograd_fn", 2084*da0073e9SAndroid Build Coastguard Worker verbose=True, 2085*da0073e9SAndroid Build Coastguard Worker ) 2086*da0073e9SAndroid Build Coastguard Worker 2087*da0073e9SAndroid Build Coastguard Worker x = torch.ones(2, 2, requires_grad=True) 2088*da0073e9SAndroid Build Coastguard Worker temp = x.clone().detach() 2089*da0073e9SAndroid Build Coastguard Worker out = torch.ops.mylib.custom_op_backed_by_autograd_fn(x) 2090*da0073e9SAndroid Build Coastguard Worker loss = out.sum() 2091*da0073e9SAndroid Build Coastguard Worker loss.backward() 2092*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.grad, temp) 2093*da0073e9SAndroid Build Coastguard Worker 2094*da0073e9SAndroid Build Coastguard Worker 2095*da0073e9SAndroid Build Coastguard Workerdef op_with_incorrect_schema(testcase, name): 2096*da0073e9SAndroid Build Coastguard Worker lib = testcase.lib() 2097*da0073e9SAndroid Build Coastguard Worker lib.define(f"{name}(Tensor x) -> Tensor") 2098*da0073e9SAndroid Build Coastguard Worker qualname = f"{testcase.test_ns}::{name}" 2099*da0073e9SAndroid Build Coastguard Worker lib.impl(name, lambda x: x[:], "CompositeExplicitAutograd") 2100*da0073e9SAndroid Build Coastguard Worker return testcase.get_op(qualname) 2101*da0073e9SAndroid Build Coastguard Worker 2102*da0073e9SAndroid Build Coastguard Worker 
2103*da0073e9SAndroid Build Coastguard Workerclass MiniOpTest(CustomOpTestCaseBase): 2104*da0073e9SAndroid Build Coastguard Worker test_ns = "mini_op_test" 2105*da0073e9SAndroid Build Coastguard Worker 2106*da0073e9SAndroid Build Coastguard Worker def _init_op_delayed_backward_error(self): 2107*da0073e9SAndroid Build Coastguard Worker name = "delayed_error" 2108*da0073e9SAndroid Build Coastguard Worker qualname = f"{self.test_ns}::{name}" 2109*da0073e9SAndroid Build Coastguard Worker lib = self.lib() 2110*da0073e9SAndroid Build Coastguard Worker lib.define(f"{name}(Tensor x) -> Tensor") 2111*da0073e9SAndroid Build Coastguard Worker lib.impl(name, lambda x: x.clone(), "CompositeExplicitAutograd") 2112*da0073e9SAndroid Build Coastguard Worker op = self.get_op(qualname) 2113*da0073e9SAndroid Build Coastguard Worker 2114*da0073e9SAndroid Build Coastguard Worker class Op(torch.autograd.Function): 2115*da0073e9SAndroid Build Coastguard Worker @staticmethod 2116*da0073e9SAndroid Build Coastguard Worker def forward(ctx, x): 2117*da0073e9SAndroid Build Coastguard Worker with torch._C._AutoDispatchBelowAutograd(): 2118*da0073e9SAndroid Build Coastguard Worker return op(x) 2119*da0073e9SAndroid Build Coastguard Worker 2120*da0073e9SAndroid Build Coastguard Worker @staticmethod 2121*da0073e9SAndroid Build Coastguard Worker def backward(ctx, grad): 2122*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 2123*da0073e9SAndroid Build Coastguard Worker 2124*da0073e9SAndroid Build Coastguard Worker def autograd_impl(x): 2125*da0073e9SAndroid Build Coastguard Worker return Op.apply(x) 2126*da0073e9SAndroid Build Coastguard Worker 2127*da0073e9SAndroid Build Coastguard Worker lib.impl(name, autograd_impl, "Autograd") 2128*da0073e9SAndroid Build Coastguard Worker return op 2129*da0073e9SAndroid Build Coastguard Worker 2130*da0073e9SAndroid Build Coastguard Worker def _init_op_with_no_abstract_impl(self): 2131*da0073e9SAndroid Build Coastguard Worker name = "no_abstract" 
2132*da0073e9SAndroid Build Coastguard Worker qualname = f"{self.test_ns}::{name}" 2133*da0073e9SAndroid Build Coastguard Worker lib = self.lib() 2134*da0073e9SAndroid Build Coastguard Worker lib.define(f"{name}(Tensor x) -> Tensor", tags=(torch.Tag.pt2_compliant_tag,)) 2135*da0073e9SAndroid Build Coastguard Worker lib.impl(name, lambda x: x.clone(), "CPU") 2136*da0073e9SAndroid Build Coastguard Worker return torch._library.utils.lookup_op(qualname) 2137*da0073e9SAndroid Build Coastguard Worker 2138*da0073e9SAndroid Build Coastguard Worker def setUp(self): 2139*da0073e9SAndroid Build Coastguard Worker super().setUp() 2140*da0073e9SAndroid Build Coastguard Worker self._op_with_no_abstract_impl = self._init_op_with_no_abstract_impl() 2141*da0073e9SAndroid Build Coastguard Worker self._op_delayed_backward_error = self._init_op_delayed_backward_error() 2142*da0073e9SAndroid Build Coastguard Worker 2143*da0073e9SAndroid Build Coastguard Worker @optests.dontGenerateOpCheckTests("Testing this API") 2144*da0073e9SAndroid Build Coastguard Worker def test_dont_generate(self): 2145*da0073e9SAndroid Build Coastguard Worker op = op_with_incorrect_schema(self, "incorrect_schema") 2146*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 2147*da0073e9SAndroid Build Coastguard Worker op(x) 2148*da0073e9SAndroid Build Coastguard Worker 2149*da0073e9SAndroid Build Coastguard Worker def test_mm(self): 2150*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3, requires_grad=True) 2151*da0073e9SAndroid Build Coastguard Worker y = torch.randn(3, 5) 2152*da0073e9SAndroid Build Coastguard Worker result = torch.ops.aten.mm.default(x, y) 2153*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, x @ y) 2154*da0073e9SAndroid Build Coastguard Worker 2155*da0073e9SAndroid Build Coastguard Worker def test_mm_meta(self): 2156*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3, requires_grad=True, device="meta") 2157*da0073e9SAndroid Build Coastguard Worker 
y = torch.randn(3, 5, device="meta") 2158*da0073e9SAndroid Build Coastguard Worker result = torch.ops.aten.mm.default(x, y) 2159*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.shape, (x @ y).shape) 2160*da0073e9SAndroid Build Coastguard Worker 2161*da0073e9SAndroid Build Coastguard Worker def test_mm_fake(self): 2162*da0073e9SAndroid Build Coastguard Worker with torch._subclasses.fake_tensor.FakeTensorMode(): 2163*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3, requires_grad=True, device="cpu") 2164*da0073e9SAndroid Build Coastguard Worker y = torch.randn(3, 5, device="cpu") 2165*da0073e9SAndroid Build Coastguard Worker result = torch.ops.aten.mm.default(x, y) 2166*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.shape, (x @ y).shape) 2167*da0073e9SAndroid Build Coastguard Worker 2168*da0073e9SAndroid Build Coastguard Worker def test_mm_errors(self): 2169*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3, requires_grad=True) 2170*da0073e9SAndroid Build Coastguard Worker y = torch.randn(4, 5) 2171*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "cannot be multiplied"): 2172*da0073e9SAndroid Build Coastguard Worker result = torch.ops.aten.mm.default(x, y) 2173*da0073e9SAndroid Build Coastguard Worker 2174*da0073e9SAndroid Build Coastguard Worker def test_nonzero(self): 2175*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([0, 1, 2, 0, 0]) 2176*da0073e9SAndroid Build Coastguard Worker y = torch.ops.aten.nonzero.default(x) 2177*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y, torch.tensor([[1], [2]])) 2178*da0073e9SAndroid Build Coastguard Worker 2179*da0073e9SAndroid Build Coastguard Worker def test_inplace(self): 2180*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 2181*da0073e9SAndroid Build Coastguard Worker x_clone = x.clone() 2182*da0073e9SAndroid Build Coastguard Worker y = torch.ops.aten.sin_(x) 2183*da0073e9SAndroid Build Coastguard 
Worker self.assertEqual(x, x_clone.sin()) 2184*da0073e9SAndroid Build Coastguard Worker 2185*da0073e9SAndroid Build Coastguard Worker def test_incorrect_schema(self): 2186*da0073e9SAndroid Build Coastguard Worker op = op_with_incorrect_schema(self, "incorrect_schema") 2187*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 2188*da0073e9SAndroid Build Coastguard Worker op(x) 2189*da0073e9SAndroid Build Coastguard Worker 2190*da0073e9SAndroid Build Coastguard Worker def test_no_abstract(self): 2191*da0073e9SAndroid Build Coastguard Worker op = self._op_with_no_abstract_impl 2192*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 2193*da0073e9SAndroid Build Coastguard Worker op(x) 2194*da0073e9SAndroid Build Coastguard Worker 2195*da0073e9SAndroid Build Coastguard Worker def test_delayed_error(self): 2196*da0073e9SAndroid Build Coastguard Worker op = self._op_delayed_backward_error 2197*da0073e9SAndroid Build Coastguard Worker x = torch.randn([], requires_grad=True) 2198*da0073e9SAndroid Build Coastguard Worker y = op(x) 2199*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(NotImplementedError): 2200*da0073e9SAndroid Build Coastguard Worker y.sum().backward() 2201*da0073e9SAndroid Build Coastguard Worker 2202*da0073e9SAndroid Build Coastguard Worker def test_delayed_error_no_requires_grad(self): 2203*da0073e9SAndroid Build Coastguard Worker op = self._op_delayed_backward_error 2204*da0073e9SAndroid Build Coastguard Worker x = torch.randn([]) 2205*da0073e9SAndroid Build Coastguard Worker y = op(x) 2206*da0073e9SAndroid Build Coastguard Worker 2207*da0073e9SAndroid Build Coastguard Worker 2208*da0073e9SAndroid Build Coastguard Workerclass TestCustomOpAPI(TestCase): 2209*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 2210*da0073e9SAndroid Build Coastguard Worker def test_basic(self): 2211*da0073e9SAndroid Build Coastguard Worker 
@torch.library.custom_op("_torch_testing::add", mutates_args=()) 2212*da0073e9SAndroid Build Coastguard Worker def add(x: Tensor, y: float) -> Tensor: 2213*da0073e9SAndroid Build Coastguard Worker x_np = x.numpy(force=True) 2214*da0073e9SAndroid Build Coastguard Worker out_np = x_np + y 2215*da0073e9SAndroid Build Coastguard Worker return torch.from_numpy(out_np).to(x.device) 2216*da0073e9SAndroid Build Coastguard Worker 2217*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 2218*da0073e9SAndroid Build Coastguard Worker y = 3.14 2219*da0073e9SAndroid Build Coastguard Worker z = add(x, y) 2220*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z, x + y) 2221*da0073e9SAndroid Build Coastguard Worker 2222*da0073e9SAndroid Build Coastguard Worker cpu_called = False 2223*da0073e9SAndroid Build Coastguard Worker 2224*da0073e9SAndroid Build Coastguard Worker @add.register_kernel("cpu") 2225*da0073e9SAndroid Build Coastguard Worker def _(x, y): 2226*da0073e9SAndroid Build Coastguard Worker nonlocal cpu_called 2227*da0073e9SAndroid Build Coastguard Worker cpu_called = True 2228*da0073e9SAndroid Build Coastguard Worker x_np = x.numpy() 2229*da0073e9SAndroid Build Coastguard Worker out_np = x_np + y 2230*da0073e9SAndroid Build Coastguard Worker return torch.from_numpy(out_np) 2231*da0073e9SAndroid Build Coastguard Worker 2232*da0073e9SAndroid Build Coastguard Worker z = add(x, y) 2233*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z, x + y) 2234*da0073e9SAndroid Build Coastguard Worker self.assertTrue(cpu_called) 2235*da0073e9SAndroid Build Coastguard Worker 2236*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 2237*da0073e9SAndroid Build Coastguard Worker def test_no_grad_skips_autograd(self): 2238*da0073e9SAndroid Build Coastguard Worker @torch.library.custom_op("_torch_testing::add", mutates_args=()) 2239*da0073e9SAndroid Build Coastguard Worker def add(x: Tensor, y: float) 
-> Tensor: 2240*da0073e9SAndroid Build Coastguard Worker x_np = x.numpy(force=True) 2241*da0073e9SAndroid Build Coastguard Worker out_np = x_np + y 2242*da0073e9SAndroid Build Coastguard Worker return torch.from_numpy(out_np).to(x.device) 2243*da0073e9SAndroid Build Coastguard Worker 2244*da0073e9SAndroid Build Coastguard Worker called = 0 2245*da0073e9SAndroid Build Coastguard Worker 2246*da0073e9SAndroid Build Coastguard Worker def setup_context(ctx, inputs, output): 2247*da0073e9SAndroid Build Coastguard Worker nonlocal called 2248*da0073e9SAndroid Build Coastguard Worker called += 1 2249*da0073e9SAndroid Build Coastguard Worker 2250*da0073e9SAndroid Build Coastguard Worker def backward(ctx, grad): 2251*da0073e9SAndroid Build Coastguard Worker raise AssertionError("should not be reached") 2252*da0073e9SAndroid Build Coastguard Worker 2253*da0073e9SAndroid Build Coastguard Worker add.register_autograd(backward, setup_context=setup_context) 2254*da0073e9SAndroid Build Coastguard Worker 2255*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, requires_grad=True) 2256*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 2257*da0073e9SAndroid Build Coastguard Worker y = add(x, 2.0) 2258*da0073e9SAndroid Build Coastguard Worker self.assertEqual(called, 0) 2259*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y, x + 2.0) 2260*da0073e9SAndroid Build Coastguard Worker 2261*da0073e9SAndroid Build Coastguard Worker x.requires_grad_(False) 2262*da0073e9SAndroid Build Coastguard Worker y = add(x, 2.0) 2263*da0073e9SAndroid Build Coastguard Worker self.assertEqual(called, 0) 2264*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y, x + 2.0) 2265*da0073e9SAndroid Build Coastguard Worker 2266*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, requires_grad=True) 2267*da0073e9SAndroid Build Coastguard Worker y = add(x, 2.0) 2268*da0073e9SAndroid Build Coastguard Worker self.assertEqual(called, 1) 2269*da0073e9SAndroid Build Coastguard 
# NOTE(review): the first statement below closes a test that starts before this chunk; its indentation is assumed — confirm.
        self.assertEqual(y, x + 2.0)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_manual_schema(self):
        # Registers two ops through an explicit `schema=` string instead of type
        # annotations: a functional `add`, and an in-place `sin_` whose schema
        # marks `x` as mutated. Both are then called and checked against torch.
        @torch.library.custom_op(
            "_torch_testing::add",
            mutates_args=(),
            schema="(Tensor x, float y) -> Tensor",
        )
        def add(x, y):
            x_np = x.numpy(force=True)
            out_np = x_np + y
            return torch.from_numpy(out_np).to(x.device)

        x = torch.randn(3)
        y = 3.14
        z = add(x, y)
        self.assertEqual(z, x + y)

        @torch.library.custom_op(
            "_torch_testing::sin_",
            mutates_args=["x"],
            schema="(Tensor(a!) x) -> ()",
        )
        def sin_(x):
            # Writes the result back through the numpy view of `x`.
            x_np = x.numpy()
            np.sin(x_np, out=x_np)

        x = torch.randn(3)
        # Computed before the call because sin_ overwrites x.
        expected = x.sin()
        sin_(x)
        self.assertEqual(x, expected)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_kwarg_only_tensors(self):
        # Each registration below is expected to raise NotImplementedError with
        # a message matching "kwarg-only Tensor args".

        # custom_op with a keyword-only Tensor argument.
        with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"):

            @torch.library.custom_op("_torch_testing::foo", mutates_args=())
            def foo(x: Tensor, *, y: int, z: Tensor) -> Tensor:
                pass

        # custom_op with a keyword-only Optional[Tensor] argument.
        with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"):

            @torch.library.custom_op("_torch_testing::foo", mutates_args=())
            def foo2(x: Tensor, *, y: int, z: Optional[Tensor]) -> Tensor:
                pass

        # custom_op with a keyword-only List[Tensor] argument.
        with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"):

            @torch.library.custom_op("_torch_testing::foo", mutates_args=())
            def foo3(x: Tensor, *, y: int, z: List[Tensor]) -> Tensor:
                pass

        # register_autograd / register_vmap on an op whose schema (defined
        # directly on the library) has a keyword-only Tensor.
        with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
            lib.define("foo(Tensor x, *, Tensor y) -> Tensor")
            with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"):
                torch.library.register_autograd(
                    "_torch_testing::foo",
                    lambda grad: grad,
                    setup_context=lambda ctx, inputs, keyword_only_inputs, output: None,
                )

            with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"):
                torch.library.register_vmap(
                    "_torch_testing::foo",
                    lambda info, in_dims, x, *, y: (x, 0),
                )

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_register_autograd_kwargonly_low_level(self):
        # register_autograd on an op defined via lib.define with a keyword-only
        # float: setup_context is expected to receive it in
        # `keyword_only_inputs`, and backward scales the gradient by it.
        with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
            lib.define("foo(Tensor x, *, float y) -> Tensor")
            called = False

            def foo_impl(x, *, y):
                return x * y

            lib.impl("foo", foo_impl, "CPU")

            def backward(ctx, grad):
                nonlocal called
                called = True
                return grad * ctx.y

            def setup_context(ctx, inputs, keyword_only_inputs, output):
                assert tuple(keyword_only_inputs.keys()) == ("y",)
                ctx.y = keyword_only_inputs["y"]

            torch.library.register_autograd(
                "_torch_testing::foo", backward, setup_context=setup_context, lib=lib
            )

            x = torch.randn(3, requires_grad=True)
            torch.ops._torch_testing.foo(x, y=3.14).sum().backward()
            self.assertTrue(called)
            # d/dx sum(x * 3.14) is 3.14 for every element.
            self.assertEqual(x.grad, torch.tensor([3.14, 3.14, 3.14]))

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_register_autograd_defaults(self):
        # The call below passes only `w` and `z`; setup_context asserts that the
        # schema defaults (x = 2, y = 3) are already filled in when it runs.
        with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
            lib.define("foo(Tensor w, int x = 2, *, int y = 3, int z) -> Tensor")

            def foo_impl(w, x=2, *, y=3, z):
                return w * x * y * z

            lib.impl("foo", foo_impl, "CPU")

            called = False

            def backward(ctx, grad):
                nonlocal called
                called = True
                return grad * ctx.c

            def setup_context(ctx, inputs, keyword_only_inputs, output):
                # Positional inputs are (w, x); keyword-only inputs are y and z.
                assert len(inputs) == 2
                assert inputs[1] == 2
                assert keyword_only_inputs == {"y": 3, "z": 42}
                ctx.c = keyword_only_inputs["y"] * keyword_only_inputs["z"] * inputs[1]

            torch.library.register_autograd(
                "_torch_testing::foo", backward, setup_context=setup_context, lib=lib
            )

            w = torch.randn(3, requires_grad=True)
            torch.ops._torch_testing.foo(w, z=42).sum().backward()
            self.assertTrue(called)
            self.assertEqual(w.grad, torch.full_like(w, 2 * 3 * 42))

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_manual_schema_error(self):
        # The schema marks `x` as mutated (`Tensor(a!)`) while mutates_args is
        # empty; registration is expected to raise ValueError.
        with self.assertRaisesRegex(ValueError, "the op mutates {'x'}"):

            @torch.library.custom_op(
                "_torch_testing::sin_",
                mutates_args=(),
                schema="(Tensor(a!) x) -> ()",
            )
            def sin_(x):
                x_np = x.numpy()
                np.sin(x_np, out=x_np)

    def test_supports_tensorlist(self):
        # Exercises the `supports_tensorlist` decorator on autograd.Function
        # subclasses whose forward takes (and may return) a list of tensors.
        @torch._library.autograd.supports_tensorlist
        class Stack(torch.autograd.Function):
            @staticmethod
            def forward(ctx, xs):
                ctx.num_xs = len(xs)
                return torch.stack(xs)

            @staticmethod
            def backward(ctx, grad):
                # needs_input_grad is expected to mirror the list structure of
                # the input: a 1-tuple holding one flag per tensor in xs.
                expected = ([True] * ctx.num_xs,)
                self.assertEqual(ctx.needs_input_grad, expected)
                return list(grad.unbind(0))

        # call apply twice, run backward only through the first result
        def t():
            return torch.randn([], requires_grad=True)

        xs0 = [t(), t(), t()]
        xs1 = [t(), t(), t(), t()]
        y0 = Stack.apply(xs0)
        y1 = Stack.apply(xs1)
        grads = torch.autograd.grad(y0.sum(), xs0)
        self.assertEqual(grads, [torch.tensor(1.0) for _ in range(3)])

        # call one apply, do multiple backwards
        xs = [t(), t(), t()]
        y = Stack.apply(xs)
        _ = torch.autograd.grad(y.sum(), xs, retain_graph=True)
        _ = torch.autograd.grad(y.sum(), xs, retain_graph=True)
        grads = torch.autograd.grad(y.sum(), xs, retain_graph=True)
        self.assertEqual(grads, [torch.tensor(1.0) for _ in range(3)])

        # error: calling forward / backward directly instead of through apply
        with self.assertRaisesRegex(NotImplementedError, "Function.forward directly"):
            Stack.forward(None, xs)
        with self.assertRaisesRegex(NotImplementedError, "Function.backward directly"):
            Stack.backward(None, xs)

        # the recursive case: forward calls apply on a shorter list
        @torch._library.autograd.supports_tensorlist
        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, xs):
                if len(xs) > 1:
                    return Foo.apply(xs[1:])
                ctx.len_xs = len(xs)
                return xs[0].sin()

            @staticmethod
            def backward(ctx, grad):
                result = [None] * ctx.len_xs
                result[-1] = grad.cos()
                return result

        # should work; the recursion bottoms out on the last element of xs
        result = Foo.apply(xs)
        expected = xs[-1].sin()
        self.assertEqual(result, expected)

        # recursive on backward: backward itself calls apply
        @torch._library.autograd.supports_tensorlist
        class Bar(torch.autograd.Function):
            @staticmethod
            def forward(ctx, xs):
                # Adds each element's position in the list to that element.
                return [xs[i] + i for i in range(len(xs))]

            @staticmethod
            def backward(ctx, grads):
                # The first two and the remaining grads go through Bar.apply
                # separately, so positions restart at 0 in each slice.
                f1 = Bar.apply(grads[:2])
                f2 = Bar.apply(grads[2:])
                return f1 + f2

        xs = [torch.tensor(0.0, requires_grad=True) for _ in range(5)]
        ys = Bar.apply(xs)
        sum(ys).backward()
        result = [xi.grad for xi in xs]
        # Incoming grads are all 1; the slices add [0, 1] and [0, 1, 2].
        self.assertEqual(result, torch.tensor([1.0, 2, 1, 2, 3]).unbind(0))

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_default_values(self):
        # Checks default arguments twice: the values the Python implementation
        # sees when called with only `x`, and the values stored in the op schema.
        defaults = []

        @torch.library.custom_op("_torch_testing::f", mutates_args=())
        def f(
            x: Tensor,
            a: Optional[int] = None,
            b: float = 3.14,
            c: bool = True,
            d: int = 3,
            e: str = "foo",
            f: torch.dtype = torch.float,
            g: torch.dtype = torch.float32,
            h: torch.dtype = torch.int,
            i: torch.device = torch.device("cpu:0"),
            j: torch.device = "cpu",
        ) -> Tensor:
            # Record what the implementation actually received.
            defaults.extend([a, b, c, d, e, f, g, h, i, j])
            return x.clone()

        x = torch.randn(3)
        f(x)
        self.assertEqual(
            defaults,
            [
                None,
                3.14,
                True,
                3,
                "foo",
                torch.float,
                torch.float32,
                torch.int,
                torch.device("cpu:0"),
                "cpu",
            ],
        )
        default_values = [
            arg.default_value
            for arg in torch.ops._torch_testing.f.default._schema.arguments
        ]
        # enum values taken from c10/core/ScalarType.h
        type_enum = {
            "float": 6,
            "int": 3,
        }
        # One entry per schema argument, starting with `x` (no default). The
        # dtype defaults are expected as integer enum values and the "cpu"
        # string default as a torch.device.
        self.assertEqual(
            default_values,
            [
                None,
                None,
                3.14,
                True,
                3,
                "foo",
                type_enum["float"],
                type_enum["float"],
                type_enum["int"],
                torch.device("cpu:0"),
                torch.device("cpu"),
            ],
        )

    def test_mutated_error(self):
        # mutates_args names "y", which is not a parameter of the function;
        # registration is expected to raise ValueError.
        with self.assertRaisesRegex(
            ValueError, r".*{'y'} in mutates_args were not found"
        ):

            @torch.library.custom_op(
                "_torch_testing::numpy_sin_inplace",
                mutates_args={"y"},
                device_types="cpu",
            )
            def numpy_sin_inplace(x: Tensor) -> None:
                x_np = x.numpy()
                np.sin(x_np, out=x_np)

    def test_mutated(self):
        # Tensors named in mutates_args are expected to have their `_version`
        # counter increased by a call to the op.
        @torch.library.custom_op(
            "_torch_testing::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu"
        )
        def numpy_sin_inplace(x: Tensor) -> None:
            x_np = x.numpy()
            np.sin(x_np, out=x_np)

        x = torch.randn(3)
        version = x._version
        expected = x.sin()
        numpy_sin_inplace(x)
        self.assertEqual(x, expected)
        self.assertGreater(x._version, version)

        # Same check for Optional[Tensor], List[Tensor] and
        # List[Optional[Tensor]] arguments; the implementation does nothing.
        @torch.library.custom_op("_torch_testing::f", mutates_args={"y", "z", "w"})
        def f(
            x: Tensor, y: Optional[Tensor], z: List[Tensor], w: List[Optional[Tensor]]
        ) -> None:
            return

        x = torch.randn(3)
        y = torch.randn(3)
        z = [torch.randn(3), torch.randn(3)]
        w = [torch.randn(3), None, torch.randn(3)]
        initial_versions = pytree.tree_map_only(
            torch.Tensor, lambda x: x._version, (x, y, z, w)
        )
        f(x, y, z, w)
        new_versions = pytree.tree_map_only(
            torch.Tensor, lambda x: x._version, (x, y, z, w)
        )

        # x is not listed in mutates_args, so its version must be unchanged.
        self.assertEqual(initial_versions[0], new_versions[0])
        initial_versions, _ = pytree.tree_flatten(initial_versions[1:])
        new_versions, _ = pytree.tree_flatten(new_versions[1:])
        for prev, after in zip(initial_versions, new_versions):
            # Skips the None entry of `w`, which has no version to compare.
            if prev is None and after is None:
                continue
            self.assertGreater(after, prev)

    def test_mutated_unknown(self):
        # With mutates_args="unknown", every tensor argument is expected to
        # have its `_version` increased by a call to the op.
        @torch.library.custom_op(
            "_torch_testing::f", mutates_args="unknown", device_types="cpu"
        )
        def f(x: Tensor) -> None:
            x_np = x.numpy()
            np.sin(x_np, out=x_np)

        x = torch.randn(3)
        version = x._version
        expected = x.sin()
        f(x)
        self.assertEqual(x, expected)
        self.assertGreater(x._version, version)

        @torch.library.custom_op("_torch_testing::f2", mutates_args="unknown")
        def f2(
            x: Tensor, y: Optional[Tensor], z: List[Tensor], w: List[Optional[Tensor]]
        ) -> None:
            return

        x = torch.randn(3)
        y = torch.randn(3)
        z = [torch.randn(3), torch.randn(3)]
        w = [torch.randn(3), None, torch.randn(3)]
        initial_versions = pytree.tree_map_only(
            torch.Tensor, lambda x: x._version, (x, y, z, w)
        )
        f2(x, y, z, w)
        new_versions = pytree.tree_map_only(
            torch.Tensor, lambda x: x._version, (x, y, z, w)
        )

        # Unlike test_mutated, `x` is included in the comparison here.
        initial_versions, _ = pytree.tree_flatten(initial_versions)
        new_versions, _ = pytree.tree_flatten(new_versions)
        for prev, after in zip(initial_versions, new_versions):
            # Skips the None entry of `w`, which has no version to compare.
            if prev is None and after is None:
                continue
            self.assertGreater(after, prev)

        # A string other than "unknown" is expected to raise ValueError.
        with self.assertRaisesRegex(ValueError, "string"):

            @torch.library.custom_op("_torch_testing::f3", mutates_args="x")
            def f3(x: Tensor) -> None:
                return

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_torch_dispatch_rule_subclass(self):
        # A torch_dispatch rule registered for ("foo", TwoTensor) is expected to
        # run exactly once: for f(z), and not for the unrelated z.cos().
        from torch.testing._internal.two_tensor import TwoTensor

        @torch.library.custom_op("mylib::foo", mutates_args={})
        def f(x: torch.Tensor) -> torch.Tensor:
            return x.sin()

        x = torch.randn(3)
        y = torch.randn(3)
        z = TwoTensor(x, y)

        with torch.library._scoped_library("mylib", "FRAGMENT") as m:
            called = 0

            def TwoTensor_foo(cls, func, types, args, kwargs):
                nonlocal called
                assert cls is TwoTensor
                called += 1
                return x.sin()

            m._register_torch_dispatch_rule("foo", TwoTensor, TwoTensor_foo)

            # Results are unused; only the call count matters.
            out = f(z)
            out2 = z.cos()

        self.assertEqual(called, 1)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_torch_dispatch_rule_mode(self):
        # Same as the subclass variant, but the rule is keyed on TwoTensorMode
        # and the op is called on a plain tensor while that mode is active.
        from torch.testing._internal.two_tensor import TwoTensorMode

        @torch.library.custom_op("mylib::foo", mutates_args={})
        def f(x: torch.Tensor) -> torch.Tensor:
            return x.sin()

        x = torch.randn(3)

        with torch.library._scoped_library("mylib", "FRAGMENT") as m:
            called = 0

            def TwoTensor_foo(mode, func, types, args, kwargs):
                nonlocal called
                called += 1
                return x.sin()

            m._register_torch_dispatch_rule("foo", TwoTensorMode, TwoTensor_foo)

            with TwoTensorMode():
                # Results are unused; only the call count matters.
                out = f(x)
                out2 = x.cos()

        self.assertEqual(called, 1)

# Tests for torch.library.register_fake and torch.library.register_torch_dispatch.
    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    @parametrize("idx", [0, 1, 2, 3, 4, 5])
    def test_library_register_fake_source(self, idx):
        # The source recorded for the fake impl of each `source{idx}` op is
        # expected to mention custom_op_db.py.
        # presumably those ops are registered in custom_op_db.py — verify there.
        opname = f"source{idx}"
        op = getattr(torch.ops._torch_testing, opname).default
        entry = torch._library.simple_registry.singleton.find(op._name)
        source = entry.fake_impl.kernel.source
        # NOTE(review): assertIsNotNone / assertIn would give clearer failure
        # messages than a bare assert and assertTrue(... in ...).
        assert source is not None
        self.assertTrue("custom_op_db.py" in source)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_fake(self):
        # register_fake is exercised with three ways of naming the op: the
        # object returned by custom_op, the qualified name, and the OpOverload.
        # The op is re-declared on every iteration.
        for mode in ["function", "qualname", "opoverload"]:

            @torch.library.custom_op("_torch_testing::add", mutates_args=())
            def add(x: Tensor, y: float) -> Tensor:
                x_np = x.cpu().numpy()
                out_np = x_np + y
                return torch.from_numpy(out_np).to(x.device)

            called = False

            if mode == "function":
                dec = torch.library.register_fake(add)
                self.assertIsNotNone(dec)
            elif mode == "qualname":
                dec = torch.library.register_fake("_torch_testing::add")
                self.assertIsNotNone(dec)
            elif mode == "opoverload":
                dec = torch.library.register_fake(torch.ops._torch_testing.add.default)
                self.assertIsNotNone(dec)
            else:
                raise AssertionError("should not get here")

            # Without a function argument, register_fake returns a decorator.
            @dec
            def _(x, y):
                nonlocal called
                called = True
                return torch.empty_like(x)

            # Under FakeTensorMode the fake impl registered above should run.
            with torch._subclasses.fake_tensor.FakeTensorMode():
                x = torch.randn(3)
                y = 3.14
                z = add(x, y)
                self.assertEqual(z.shape, x.shape)
                self.assertTrue(called)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_torch_dispatch(self):
        # register_torch_dispatch is exercised with the same three ways of
        # naming the op; the registered rule should run when the op is called
        # under MyMode.
        for mode in ["function", "qualname", "opoverload"]:

            class MyMode(torch.utils._python_dispatch.TorchDispatchMode):
                # NOTE(review): `**kwargs` raises TypeError if kwargs is left at
                # its None default; presumably kwargs is always passed — confirm.
                def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                    return func(*args, **kwargs)

            @torch.library.custom_op("_torch_testing::add", mutates_args=())
            def add(x: Tensor, y: float) -> Tensor:
                x_np = x.cpu().numpy()
                out_np = x_np + y
                return torch.from_numpy(out_np).to(x.device)

            called = False

            if mode == "function":
                dec = torch.library.register_torch_dispatch(add, MyMode)
                self.assertIsNotNone(dec)
            elif mode == "qualname":
                dec = torch.library.register_torch_dispatch(
                    "_torch_testing::add", MyMode
                )
                self.assertIsNotNone(dec)
            elif mode == "opoverload":
                dec = torch.library.register_torch_dispatch(
                    torch.ops._torch_testing.add.default, MyMode
                )
                self.assertIsNotNone(dec)
            else:
                raise AssertionError("should not get here")

            @dec
            def _(mode, func, types, args, kwargs):
                nonlocal called
                called = True
                return func(*args, **kwargs)

            with MyMode():
                x = torch.randn(3)
                y = 3.14
                z = add(x, y)
                self.assertEqual(z.shape, x.shape)
                self.assertTrue(called)

    # NOTE(review): the test below continues past the end of this chunk; only its opening lines appear here.
    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_torch_dispatch_low_level(self):
        modes = ["qualname", "opoverload"]
        calls = ["decorator", "function"]
        device_types_options = [("cpu", "cuda"), "cpu", None]

        for mode, call, device_types in itertools.product(
            modes, calls, device_types_options
        ):
            with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
                lib.define("add10(Tensor x, float y) -> Tensor")

                if mode == "qualname":
Worker op = "_torch_testing::add10" 2815*da0073e9SAndroid Build Coastguard Worker else: 2816*da0073e9SAndroid Build Coastguard Worker assert mode == "opoverload" 2817*da0073e9SAndroid Build Coastguard Worker op = torch.ops._torch_testing.add10.default 2818*da0073e9SAndroid Build Coastguard Worker 2819*da0073e9SAndroid Build Coastguard Worker called = False 2820*da0073e9SAndroid Build Coastguard Worker 2821*da0073e9SAndroid Build Coastguard Worker class MyMode(torch.utils._python_dispatch.TorchDispatchMode): 2822*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(self, func, types, args=(), kwargs=None): 2823*da0073e9SAndroid Build Coastguard Worker return func(*args, **kwargs) 2824*da0073e9SAndroid Build Coastguard Worker 2825*da0073e9SAndroid Build Coastguard Worker if call == "decorator": 2826*da0073e9SAndroid Build Coastguard Worker 2827*da0073e9SAndroid Build Coastguard Worker @torch.library.register_torch_dispatch(op, MyMode, lib=lib) 2828*da0073e9SAndroid Build Coastguard Worker def _(mode, func, types, args, kwargs): 2829*da0073e9SAndroid Build Coastguard Worker x, y = args 2830*da0073e9SAndroid Build Coastguard Worker nonlocal called 2831*da0073e9SAndroid Build Coastguard Worker called = True 2832*da0073e9SAndroid Build Coastguard Worker return x + y 2833*da0073e9SAndroid Build Coastguard Worker 2834*da0073e9SAndroid Build Coastguard Worker else: 2835*da0073e9SAndroid Build Coastguard Worker assert call == "function" 2836*da0073e9SAndroid Build Coastguard Worker 2837*da0073e9SAndroid Build Coastguard Worker def add_stuff(mode, func, types, args, kwargs): 2838*da0073e9SAndroid Build Coastguard Worker x, y = args 2839*da0073e9SAndroid Build Coastguard Worker nonlocal called 2840*da0073e9SAndroid Build Coastguard Worker called = True 2841*da0073e9SAndroid Build Coastguard Worker return x + y 2842*da0073e9SAndroid Build Coastguard Worker 2843*da0073e9SAndroid Build Coastguard Worker torch.library.register_torch_dispatch( 2844*da0073e9SAndroid Build 
Coastguard Worker op, MyMode, add_stuff, lib=lib 2845*da0073e9SAndroid Build Coastguard Worker ) 2846*da0073e9SAndroid Build Coastguard Worker 2847*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 2848*da0073e9SAndroid Build Coastguard Worker y = 3.14 2849*da0073e9SAndroid Build Coastguard Worker with MyMode(): 2850*da0073e9SAndroid Build Coastguard Worker z = torch.ops._torch_testing.add10.default(x, y) 2851*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z, x + y) 2852*da0073e9SAndroid Build Coastguard Worker self.assertTrue(called) 2853*da0073e9SAndroid Build Coastguard Worker 2854*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 2855*da0073e9SAndroid Build Coastguard Worker def test_library_register_kernel(self): 2856*da0073e9SAndroid Build Coastguard Worker modes = ["function", "qualname", "opoverload"] 2857*da0073e9SAndroid Build Coastguard Worker calls = ["decorator", "function"] 2858*da0073e9SAndroid Build Coastguard Worker device_types_options = ["cpu", None] 2859*da0073e9SAndroid Build Coastguard Worker 2860*da0073e9SAndroid Build Coastguard Worker for mode, call, device_types in itertools.product( 2861*da0073e9SAndroid Build Coastguard Worker modes, calls, device_types_options 2862*da0073e9SAndroid Build Coastguard Worker ): 2863*da0073e9SAndroid Build Coastguard Worker 2864*da0073e9SAndroid Build Coastguard Worker @torch.library.custom_op( 2865*da0073e9SAndroid Build Coastguard Worker "_torch_testing::add", mutates_args=(), device_types="cuda" 2866*da0073e9SAndroid Build Coastguard Worker ) 2867*da0073e9SAndroid Build Coastguard Worker def add(x: Tensor, y: float) -> Tensor: 2868*da0073e9SAndroid Build Coastguard Worker x_np = x.cpu().numpy() 2869*da0073e9SAndroid Build Coastguard Worker out_np = x_np + y 2870*da0073e9SAndroid Build Coastguard Worker return torch.from_numpy(out_np).to(x.device) 2871*da0073e9SAndroid Build Coastguard Worker 2872*da0073e9SAndroid 
Build Coastguard Worker if mode == "function": 2873*da0073e9SAndroid Build Coastguard Worker op = add 2874*da0073e9SAndroid Build Coastguard Worker elif mode == "qualname": 2875*da0073e9SAndroid Build Coastguard Worker op = "_torch_testing::add" 2876*da0073e9SAndroid Build Coastguard Worker else: 2877*da0073e9SAndroid Build Coastguard Worker assert mode == "opoverload" 2878*da0073e9SAndroid Build Coastguard Worker op = torch.ops._torch_testing.add.default 2879*da0073e9SAndroid Build Coastguard Worker 2880*da0073e9SAndroid Build Coastguard Worker called = False 2881*da0073e9SAndroid Build Coastguard Worker 2882*da0073e9SAndroid Build Coastguard Worker if call == "decorator": 2883*da0073e9SAndroid Build Coastguard Worker 2884*da0073e9SAndroid Build Coastguard Worker @torch.library.register_kernel(op, device_types) 2885*da0073e9SAndroid Build Coastguard Worker def _(x, y): 2886*da0073e9SAndroid Build Coastguard Worker nonlocal called 2887*da0073e9SAndroid Build Coastguard Worker called = True 2888*da0073e9SAndroid Build Coastguard Worker x_np = x.numpy() 2889*da0073e9SAndroid Build Coastguard Worker out_np = x_np + y 2890*da0073e9SAndroid Build Coastguard Worker return torch.from_numpy(out_np) 2891*da0073e9SAndroid Build Coastguard Worker 2892*da0073e9SAndroid Build Coastguard Worker else: 2893*da0073e9SAndroid Build Coastguard Worker assert call == "function" 2894*da0073e9SAndroid Build Coastguard Worker 2895*da0073e9SAndroid Build Coastguard Worker def add_cpu(x, y): 2896*da0073e9SAndroid Build Coastguard Worker nonlocal called 2897*da0073e9SAndroid Build Coastguard Worker called = True 2898*da0073e9SAndroid Build Coastguard Worker x_np = x.numpy() 2899*da0073e9SAndroid Build Coastguard Worker out_np = x_np + y 2900*da0073e9SAndroid Build Coastguard Worker return torch.from_numpy(out_np) 2901*da0073e9SAndroid Build Coastguard Worker 2902*da0073e9SAndroid Build Coastguard Worker torch.library.register_kernel(op, device_types, add_cpu) 2903*da0073e9SAndroid Build 
Coastguard Worker 2904*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 2905*da0073e9SAndroid Build Coastguard Worker y = 3.14 2906*da0073e9SAndroid Build Coastguard Worker z = add(x, y) 2907*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z, x + y) 2908*da0073e9SAndroid Build Coastguard Worker self.assertTrue(called) 2909*da0073e9SAndroid Build Coastguard Worker 2910*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 2911*da0073e9SAndroid Build Coastguard Worker def test_library_register_kernel_low_level(self): 2912*da0073e9SAndroid Build Coastguard Worker modes = ["qualname", "opoverload"] 2913*da0073e9SAndroid Build Coastguard Worker calls = ["decorator", "function"] 2914*da0073e9SAndroid Build Coastguard Worker device_types_options = [("cpu", "cuda"), "cpu", None] 2915*da0073e9SAndroid Build Coastguard Worker 2916*da0073e9SAndroid Build Coastguard Worker for mode, call, device_types in itertools.product( 2917*da0073e9SAndroid Build Coastguard Worker modes, calls, device_types_options 2918*da0073e9SAndroid Build Coastguard Worker ): 2919*da0073e9SAndroid Build Coastguard Worker with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib: 2920*da0073e9SAndroid Build Coastguard Worker lib.define("add9(Tensor x, float y) -> Tensor") 2921*da0073e9SAndroid Build Coastguard Worker 2922*da0073e9SAndroid Build Coastguard Worker if mode == "qualname": 2923*da0073e9SAndroid Build Coastguard Worker op = "_torch_testing::add9" 2924*da0073e9SAndroid Build Coastguard Worker else: 2925*da0073e9SAndroid Build Coastguard Worker assert mode == "opoverload" 2926*da0073e9SAndroid Build Coastguard Worker op = torch.ops._torch_testing.add9.default 2927*da0073e9SAndroid Build Coastguard Worker 2928*da0073e9SAndroid Build Coastguard Worker called = False 2929*da0073e9SAndroid Build Coastguard Worker 2930*da0073e9SAndroid Build Coastguard Worker if call == "decorator": 
2931*da0073e9SAndroid Build Coastguard Worker 2932*da0073e9SAndroid Build Coastguard Worker @torch.library.register_kernel(op, device_types, lib=lib) 2933*da0073e9SAndroid Build Coastguard Worker def _(x, y): 2934*da0073e9SAndroid Build Coastguard Worker nonlocal called 2935*da0073e9SAndroid Build Coastguard Worker called = True 2936*da0073e9SAndroid Build Coastguard Worker x_np = x.numpy() 2937*da0073e9SAndroid Build Coastguard Worker out_np = x_np + y 2938*da0073e9SAndroid Build Coastguard Worker return torch.from_numpy(out_np) 2939*da0073e9SAndroid Build Coastguard Worker 2940*da0073e9SAndroid Build Coastguard Worker else: 2941*da0073e9SAndroid Build Coastguard Worker assert call == "function" 2942*da0073e9SAndroid Build Coastguard Worker 2943*da0073e9SAndroid Build Coastguard Worker def add_cpu(x, y): 2944*da0073e9SAndroid Build Coastguard Worker nonlocal called 2945*da0073e9SAndroid Build Coastguard Worker called = True 2946*da0073e9SAndroid Build Coastguard Worker x_np = x.numpy() 2947*da0073e9SAndroid Build Coastguard Worker out_np = x_np + y 2948*da0073e9SAndroid Build Coastguard Worker return torch.from_numpy(out_np) 2949*da0073e9SAndroid Build Coastguard Worker 2950*da0073e9SAndroid Build Coastguard Worker torch.library.register_kernel(op, device_types, add_cpu, lib=lib) 2951*da0073e9SAndroid Build Coastguard Worker 2952*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 2953*da0073e9SAndroid Build Coastguard Worker y = 3.14 2954*da0073e9SAndroid Build Coastguard Worker z = torch.ops._torch_testing.add9.default(x, y) 2955*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z, x + y) 2956*da0073e9SAndroid Build Coastguard Worker self.assertTrue(called) 2957*da0073e9SAndroid Build Coastguard Worker 2958*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 2959*da0073e9SAndroid Build Coastguard Worker def test_library_register_autograd(self): 2960*da0073e9SAndroid Build 
Coastguard Worker for mode in ["function", "qualname", "opoverload"]: 2961*da0073e9SAndroid Build Coastguard Worker 2962*da0073e9SAndroid Build Coastguard Worker @torch.library.custom_op("mylib::numpy_sin", mutates_args=()) 2963*da0073e9SAndroid Build Coastguard Worker def numpy_sin(x: Tensor) -> Tensor: 2964*da0073e9SAndroid Build Coastguard Worker x_np = x.cpu().numpy() 2965*da0073e9SAndroid Build Coastguard Worker y_np = np.sin(x_np) 2966*da0073e9SAndroid Build Coastguard Worker return torch.from_numpy(y_np).to(device=x.device) 2967*da0073e9SAndroid Build Coastguard Worker 2968*da0073e9SAndroid Build Coastguard Worker def setup_context(ctx, inputs, output) -> Tensor: 2969*da0073e9SAndroid Build Coastguard Worker (x,) = inputs 2970*da0073e9SAndroid Build Coastguard Worker ctx.save_for_backward(x) 2971*da0073e9SAndroid Build Coastguard Worker 2972*da0073e9SAndroid Build Coastguard Worker called = False 2973*da0073e9SAndroid Build Coastguard Worker 2974*da0073e9SAndroid Build Coastguard Worker def backward(ctx, grad): 2975*da0073e9SAndroid Build Coastguard Worker nonlocal called 2976*da0073e9SAndroid Build Coastguard Worker called = True 2977*da0073e9SAndroid Build Coastguard Worker (x,) = ctx.saved_tensors 2978*da0073e9SAndroid Build Coastguard Worker return grad * x.cos() 2979*da0073e9SAndroid Build Coastguard Worker 2980*da0073e9SAndroid Build Coastguard Worker if mode == "function": 2981*da0073e9SAndroid Build Coastguard Worker torch.library.register_autograd( 2982*da0073e9SAndroid Build Coastguard Worker numpy_sin, backward, setup_context=setup_context 2983*da0073e9SAndroid Build Coastguard Worker ) 2984*da0073e9SAndroid Build Coastguard Worker elif mode == "qualname": 2985*da0073e9SAndroid Build Coastguard Worker torch.library.register_autograd( 2986*da0073e9SAndroid Build Coastguard Worker "mylib::numpy_sin", backward, setup_context=setup_context 2987*da0073e9SAndroid Build Coastguard Worker ) 2988*da0073e9SAndroid Build Coastguard Worker elif mode == 
"opoverload": 2989*da0073e9SAndroid Build Coastguard Worker torch.library.register_autograd( 2990*da0073e9SAndroid Build Coastguard Worker torch.ops.mylib.numpy_sin.default, 2991*da0073e9SAndroid Build Coastguard Worker backward, 2992*da0073e9SAndroid Build Coastguard Worker setup_context=setup_context, 2993*da0073e9SAndroid Build Coastguard Worker ) 2994*da0073e9SAndroid Build Coastguard Worker 2995*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, requires_grad=True) 2996*da0073e9SAndroid Build Coastguard Worker y = numpy_sin(x) 2997*da0073e9SAndroid Build Coastguard Worker (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y)) 2998*da0073e9SAndroid Build Coastguard Worker self.assertTrue(called) 2999*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grad_x, x.cos()) 3000*da0073e9SAndroid Build Coastguard Worker 3001*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 3002*da0073e9SAndroid Build Coastguard Worker def test_library_register_autograd_low_level(self): 3003*da0073e9SAndroid Build Coastguard Worker for mode in ["qualname", "opoverload"]: 3004*da0073e9SAndroid Build Coastguard Worker with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib: 3005*da0073e9SAndroid Build Coastguard Worker lib.define("sin5(Tensor x) -> Tensor") 3006*da0073e9SAndroid Build Coastguard Worker 3007*da0073e9SAndroid Build Coastguard Worker def numpy_sin(x: Tensor) -> Tensor: 3008*da0073e9SAndroid Build Coastguard Worker x_np = x.cpu().detach().numpy() 3009*da0073e9SAndroid Build Coastguard Worker y_np = np.sin(x_np) 3010*da0073e9SAndroid Build Coastguard Worker return torch.from_numpy(y_np).to(device=x.device) 3011*da0073e9SAndroid Build Coastguard Worker 3012*da0073e9SAndroid Build Coastguard Worker def setup_context(ctx, inputs, output) -> Tensor: 3013*da0073e9SAndroid Build Coastguard Worker (x,) = inputs 3014*da0073e9SAndroid Build Coastguard Worker 
ctx.save_for_backward(x) 3015*da0073e9SAndroid Build Coastguard Worker 3016*da0073e9SAndroid Build Coastguard Worker called = False 3017*da0073e9SAndroid Build Coastguard Worker 3018*da0073e9SAndroid Build Coastguard Worker def backward(ctx, grad): 3019*da0073e9SAndroid Build Coastguard Worker nonlocal called 3020*da0073e9SAndroid Build Coastguard Worker called = True 3021*da0073e9SAndroid Build Coastguard Worker (x,) = ctx.saved_tensors 3022*da0073e9SAndroid Build Coastguard Worker return grad * x.cos() 3023*da0073e9SAndroid Build Coastguard Worker 3024*da0073e9SAndroid Build Coastguard Worker lib.impl("sin5", numpy_sin, "CPU") 3025*da0073e9SAndroid Build Coastguard Worker 3026*da0073e9SAndroid Build Coastguard Worker called = False 3027*da0073e9SAndroid Build Coastguard Worker 3028*da0073e9SAndroid Build Coastguard Worker if mode == "qualname": 3029*da0073e9SAndroid Build Coastguard Worker torch.library.register_autograd( 3030*da0073e9SAndroid Build Coastguard Worker "_torch_testing::sin5", 3031*da0073e9SAndroid Build Coastguard Worker backward, 3032*da0073e9SAndroid Build Coastguard Worker setup_context=setup_context, 3033*da0073e9SAndroid Build Coastguard Worker lib=lib, 3034*da0073e9SAndroid Build Coastguard Worker ) 3035*da0073e9SAndroid Build Coastguard Worker elif mode == "opoverload": 3036*da0073e9SAndroid Build Coastguard Worker torch.library.register_autograd( 3037*da0073e9SAndroid Build Coastguard Worker torch.ops._torch_testing.sin5.default, 3038*da0073e9SAndroid Build Coastguard Worker backward, 3039*da0073e9SAndroid Build Coastguard Worker setup_context=setup_context, 3040*da0073e9SAndroid Build Coastguard Worker lib=lib, 3041*da0073e9SAndroid Build Coastguard Worker ) 3042*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, requires_grad=True) 3043*da0073e9SAndroid Build Coastguard Worker y = torch.ops._torch_testing.sin5(x) 3044*da0073e9SAndroid Build Coastguard Worker (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y)) 
3045*da0073e9SAndroid Build Coastguard Worker self.assertTrue(called) 3046*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grad_x, x.cos()) 3047*da0073e9SAndroid Build Coastguard Worker 3048*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 3049*da0073e9SAndroid Build Coastguard Worker def test_fake(self): 3050*da0073e9SAndroid Build Coastguard Worker @torch.library.custom_op("_torch_testing::add", mutates_args=()) 3051*da0073e9SAndroid Build Coastguard Worker def add(x: Tensor, y: float) -> Tensor: 3052*da0073e9SAndroid Build Coastguard Worker x_np = x.cpu().numpy() 3053*da0073e9SAndroid Build Coastguard Worker out_np = x_np + y 3054*da0073e9SAndroid Build Coastguard Worker return torch.from_numpy(out_np).to(x.device) 3055*da0073e9SAndroid Build Coastguard Worker 3056*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 3057*da0073e9SAndroid Build Coastguard Worker y = 3.14 3058*da0073e9SAndroid Build Coastguard Worker z = add(x, y) 3059*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z, x + y) 3060*da0073e9SAndroid Build Coastguard Worker 3061*da0073e9SAndroid Build Coastguard Worker try: 3062*da0073e9SAndroid Build Coastguard Worker with torch._subclasses.fake_tensor.FakeTensorMode(): 3063*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 3064*da0073e9SAndroid Build Coastguard Worker add(x, y) 3065*da0073e9SAndroid Build Coastguard Worker raise AssertionError("should not be hit") 3066*da0073e9SAndroid Build Coastguard Worker except RuntimeError as e: 3067*da0073e9SAndroid Build Coastguard Worker abstract_impl_error_msg = str(e) 3068*da0073e9SAndroid Build Coastguard Worker abstract_impl_error_msg = re.sub( 3069*da0073e9SAndroid Build Coastguard Worker r"0x.*>\)>", "0xDEADBEEF>)>", abstract_impl_error_msg 3070*da0073e9SAndroid Build Coastguard Worker ).replace(". 
", ".\n") 3071*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 3072*da0073e9SAndroid Build Coastguard Worker abstract_impl_error_msg, 3073*da0073e9SAndroid Build Coastguard Worker """\ 3074*da0073e9SAndroid Build Coastguard WorkerThere was no fake impl registered for <CustomOpDef(_torch_testing::add)>. 3075*da0073e9SAndroid Build Coastguard WorkerThis is necessary for torch.compile/export/fx tracing to work. 3076*da0073e9SAndroid Build Coastguard WorkerPlease use `add.register_fake` to add an fake impl.""", 3077*da0073e9SAndroid Build Coastguard Worker ) 3078*da0073e9SAndroid Build Coastguard Worker 3079*da0073e9SAndroid Build Coastguard Worker if not IS_WINDOWS: 3080*da0073e9SAndroid Build Coastguard Worker 3081*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager") 3082*da0073e9SAndroid Build Coastguard Worker def f(x, y): 3083*da0073e9SAndroid Build Coastguard Worker return add(x, y) 3084*da0073e9SAndroid Build Coastguard Worker 3085*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 3086*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "no fake impl"): 3087*da0073e9SAndroid Build Coastguard Worker f(x, y) 3088*da0073e9SAndroid Build Coastguard Worker 3089*da0073e9SAndroid Build Coastguard Worker abstract_called = False 3090*da0073e9SAndroid Build Coastguard Worker 3091*da0073e9SAndroid Build Coastguard Worker @add.register_fake 3092*da0073e9SAndroid Build Coastguard Worker def _(x, y): 3093*da0073e9SAndroid Build Coastguard Worker nonlocal abstract_called 3094*da0073e9SAndroid Build Coastguard Worker abstract_called = True 3095*da0073e9SAndroid Build Coastguard Worker return torch.empty_like(x) 3096*da0073e9SAndroid Build Coastguard Worker 3097*da0073e9SAndroid Build Coastguard Worker with torch._subclasses.fake_tensor.FakeTensorMode(): 3098*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 3099*da0073e9SAndroid Build Coastguard Worker z = add(x, y) 3100*da0073e9SAndroid 
Build Coastguard Worker self.assertEqual(z.shape, x.shape) 3101*da0073e9SAndroid Build Coastguard Worker self.assertTrue(abstract_called) 3102*da0073e9SAndroid Build Coastguard Worker 3103*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("recursive dynamo") 3104*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work on windows") 3105*da0073e9SAndroid Build Coastguard Worker def test_compile(self): 3106*da0073e9SAndroid Build Coastguard Worker called_impl = False 3107*da0073e9SAndroid Build Coastguard Worker called_abstract = False 3108*da0073e9SAndroid Build Coastguard Worker 3109*da0073e9SAndroid Build Coastguard Worker @torch.library.custom_op("_torch_testing::linear", mutates_args=()) 3110*da0073e9SAndroid Build Coastguard Worker def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor: 3111*da0073e9SAndroid Build Coastguard Worker nonlocal called_impl 3112*da0073e9SAndroid Build Coastguard Worker called_impl = True 3113*da0073e9SAndroid Build Coastguard Worker x_np = x.numpy() 3114*da0073e9SAndroid Build Coastguard Worker w_np = weight.numpy() 3115*da0073e9SAndroid Build Coastguard Worker b_np = bias.numpy() 3116*da0073e9SAndroid Build Coastguard Worker out_np = np.add(x_np @ w_np.T, bias) 3117*da0073e9SAndroid Build Coastguard Worker return out_np 3118*da0073e9SAndroid Build Coastguard Worker 3119*da0073e9SAndroid Build Coastguard Worker @custom_linear.register_fake 3120*da0073e9SAndroid Build Coastguard Worker def _(x, weight, bias): 3121*da0073e9SAndroid Build Coastguard Worker nonlocal called_abstract 3122*da0073e9SAndroid Build Coastguard Worker called_abstract = True 3123*da0073e9SAndroid Build Coastguard Worker assert x.dim() == 2 3124*da0073e9SAndroid Build Coastguard Worker assert weight.dim() == 2 3125*da0073e9SAndroid Build Coastguard Worker assert bias.dim() == 1 3126*da0073e9SAndroid Build Coastguard Worker assert x.shape[1] == weight.shape[1] 3127*da0073e9SAndroid Build Coastguard 
Worker assert weight.shape[0] == bias.shape[0] 3128*da0073e9SAndroid Build Coastguard Worker assert x.device == weight.device 3129*da0073e9SAndroid Build Coastguard Worker return x.new_empty(x.size(0), weight.size(0)) 3130*da0073e9SAndroid Build Coastguard Worker 3131*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 2) 3132*da0073e9SAndroid Build Coastguard Worker weight = torch.randn(2, 2) 3133*da0073e9SAndroid Build Coastguard Worker bias = torch.randn(2) 3134*da0073e9SAndroid Build Coastguard Worker out = torch.compile(custom_linear, backend="eager", fullgraph=True)( 3135*da0073e9SAndroid Build Coastguard Worker x, weight, bias 3136*da0073e9SAndroid Build Coastguard Worker ) 3137*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, torch.nn.functional.linear(x, weight, bias)) 3138*da0073e9SAndroid Build Coastguard Worker self.assertTrue(called_impl) 3139*da0073e9SAndroid Build Coastguard Worker self.assertTrue(called_abstract) 3140*da0073e9SAndroid Build Coastguard Worker 3141*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 3142*da0073e9SAndroid Build Coastguard Worker def test_register_autograd_error_cases(self): 3143*da0073e9SAndroid Build Coastguard Worker @torch.library.custom_op("_torch_testing::g", mutates_args=()) 3144*da0073e9SAndroid Build Coastguard Worker def g(x: Tensor) -> Tensor: 3145*da0073e9SAndroid Build Coastguard Worker return x.sin() 3146*da0073e9SAndroid Build Coastguard Worker 3147*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, requires_grad=True) 3148*da0073e9SAndroid Build Coastguard Worker y = g(x) 3149*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "no autograd formula"): 3150*da0073e9SAndroid Build Coastguard Worker y.sum().backward() 3151*da0073e9SAndroid Build Coastguard Worker 3152*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a 
bug") 3153*da0073e9SAndroid Build Coastguard Worker def test_replacement(self): 3154*da0073e9SAndroid Build Coastguard Worker @torch.library.custom_op("_torch_testing::f", mutates_args=()) 3155*da0073e9SAndroid Build Coastguard Worker def f(x: Tensor) -> Tensor: 3156*da0073e9SAndroid Build Coastguard Worker return x.sin() 3157*da0073e9SAndroid Build Coastguard Worker 3158*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 3159*da0073e9SAndroid Build Coastguard Worker y = f(x) 3160*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y, x.sin()) 3161*da0073e9SAndroid Build Coastguard Worker 3162*da0073e9SAndroid Build Coastguard Worker @torch.library.custom_op("_torch_testing::f", mutates_args=()) 3163*da0073e9SAndroid Build Coastguard Worker def f(x: Tensor) -> Tensor: 3164*da0073e9SAndroid Build Coastguard Worker return x.cos() 3165*da0073e9SAndroid Build Coastguard Worker 3166*da0073e9SAndroid Build Coastguard Worker y = f(x) 3167*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y, x.cos()) 3168*da0073e9SAndroid Build Coastguard Worker 3169*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 3170*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA, "requires CUDA") 3171*da0073e9SAndroid Build Coastguard Worker def test_split_device(self): 3172*da0073e9SAndroid Build Coastguard Worker cpu_call_count = 0 3173*da0073e9SAndroid Build Coastguard Worker cuda_call_count = 0 3174*da0073e9SAndroid Build Coastguard Worker 3175*da0073e9SAndroid Build Coastguard Worker @torch.library.custom_op( 3176*da0073e9SAndroid Build Coastguard Worker "_torch_testing::f", mutates_args=(), device_types="cpu" 3177*da0073e9SAndroid Build Coastguard Worker ) 3178*da0073e9SAndroid Build Coastguard Worker def f(x: Tensor) -> Tensor: 3179*da0073e9SAndroid Build Coastguard Worker nonlocal cpu_call_count 3180*da0073e9SAndroid Build Coastguard Worker cpu_call_count += 1 
3181*da0073e9SAndroid Build Coastguard Worker x_np = x.numpy() 3182*da0073e9SAndroid Build Coastguard Worker out_np = np.sin(x_np) 3183*da0073e9SAndroid Build Coastguard Worker return torch.from_numpy(out_np) 3184*da0073e9SAndroid Build Coastguard Worker 3185*da0073e9SAndroid Build Coastguard Worker @f.register_kernel("cuda") 3186*da0073e9SAndroid Build Coastguard Worker def _(x: Tensor) -> Tensor: 3187*da0073e9SAndroid Build Coastguard Worker nonlocal cuda_call_count 3188*da0073e9SAndroid Build Coastguard Worker cuda_call_count += 1 3189*da0073e9SAndroid Build Coastguard Worker x_np = x.cpu().numpy() 3190*da0073e9SAndroid Build Coastguard Worker out_np = np.sin(x_np) 3191*da0073e9SAndroid Build Coastguard Worker return torch.from_numpy(out_np).to(x.device) 3192*da0073e9SAndroid Build Coastguard Worker 3193*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 3194*da0073e9SAndroid Build Coastguard Worker y = f(x) 3195*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y, x.sin()) 3196*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cpu_call_count, 1) 3197*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cuda_call_count, 0) 3198*da0073e9SAndroid Build Coastguard Worker 3199*da0073e9SAndroid Build Coastguard Worker x = x.cuda() 3200*da0073e9SAndroid Build Coastguard Worker y = f(x) 3201*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y, x.sin()) 3202*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cpu_call_count, 1) 3203*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cuda_call_count, 1) 3204*da0073e9SAndroid Build Coastguard Worker 3205*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 3206*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA, "requires CUDA") 3207*da0073e9SAndroid Build Coastguard Worker def test_multi_types(self): 3208*da0073e9SAndroid Build Coastguard Worker @torch.library.custom_op( 
3209*da0073e9SAndroid Build Coastguard Worker "_torch_testing::f", mutates_args=(), device_types=("cpu", "cuda") 3210*da0073e9SAndroid Build Coastguard Worker ) 3211*da0073e9SAndroid Build Coastguard Worker def f(x: Tensor) -> Tensor: 3212*da0073e9SAndroid Build Coastguard Worker x_np = x.cpu().numpy() 3213*da0073e9SAndroid Build Coastguard Worker out_np = np.sin(x_np) 3214*da0073e9SAndroid Build Coastguard Worker return torch.from_numpy(out_np).to(x.device) 3215*da0073e9SAndroid Build Coastguard Worker 3216*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 3217*da0073e9SAndroid Build Coastguard Worker y = f(x) 3218*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y, x.sin()) 3219*da0073e9SAndroid Build Coastguard Worker x = x.cuda() 3220*da0073e9SAndroid Build Coastguard Worker y = f(x) 3221*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y, x.sin()) 3222*da0073e9SAndroid Build Coastguard Worker 3223*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 3224*da0073e9SAndroid Build Coastguard Worker def test_overloading(self): 3225*da0073e9SAndroid Build Coastguard Worker called_f = 0 3226*da0073e9SAndroid Build Coastguard Worker called_f1 = 0 3227*da0073e9SAndroid Build Coastguard Worker 3228*da0073e9SAndroid Build Coastguard Worker @torch.library.custom_op("_torch_testing::f", mutates_args=()) 3229*da0073e9SAndroid Build Coastguard Worker def f(x: Tensor) -> Tensor: 3230*da0073e9SAndroid Build Coastguard Worker nonlocal called_f 3231*da0073e9SAndroid Build Coastguard Worker called_f += 1 3232*da0073e9SAndroid Build Coastguard Worker return x.clone() 3233*da0073e9SAndroid Build Coastguard Worker 3234*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3) 3235*da0073e9SAndroid Build Coastguard Worker torch.ops._torch_testing.f(x) 3236*da0073e9SAndroid Build Coastguard Worker self.assertEqual(called_f, 1) 3237*da0073e9SAndroid Build Coastguard Worker 
3238*da0073e9SAndroid Build Coastguard Worker @torch.library.custom_op("_torch_testing::f.overload", mutates_args=()) 3239*da0073e9SAndroid Build Coastguard Worker def f1(x: Tensor, y: Tensor) -> Tensor: 3240*da0073e9SAndroid Build Coastguard Worker nonlocal called_f1 3241*da0073e9SAndroid Build Coastguard Worker called_f1 += 1 3242*da0073e9SAndroid Build Coastguard Worker return x.clone() 3243*da0073e9SAndroid Build Coastguard Worker 3244*da0073e9SAndroid Build Coastguard Worker torch.ops._torch_testing.f(x, x) 3245*da0073e9SAndroid Build Coastguard Worker self.assertEqual(called_f1, 1) 3246*da0073e9SAndroid Build Coastguard Worker 3247*da0073e9SAndroid Build Coastguard Worker def test_disallows_output_aliasing(self): 3248*da0073e9SAndroid Build Coastguard Worker @torch.library.custom_op("_torch_testing::f", mutates_args=()) 3249*da0073e9SAndroid Build Coastguard Worker def f(x: Tensor) -> Tensor: 3250*da0073e9SAndroid Build Coastguard Worker return x.view(-1) 3251*da0073e9SAndroid Build Coastguard Worker 3252*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 3253*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "may not alias"): 3254*da0073e9SAndroid Build Coastguard Worker f(x) 3255*da0073e9SAndroid Build Coastguard Worker 3256*da0073e9SAndroid Build Coastguard Worker @torch.library.custom_op("_torch_testing::f", mutates_args=()) 3257*da0073e9SAndroid Build Coastguard Worker def f(x: Tensor) -> Tensor: 3258*da0073e9SAndroid Build Coastguard Worker return x 3259*da0073e9SAndroid Build Coastguard Worker 3260*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 3261*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "may not alias"): 3262*da0073e9SAndroid Build Coastguard Worker f(x) 3263*da0073e9SAndroid Build Coastguard Worker 3264*da0073e9SAndroid Build Coastguard Worker @torch.library.custom_op( 3265*da0073e9SAndroid Build Coastguard Worker "_torch_testing::f", 
mutates_args={"x"}, device_types="cpu" 3266*da0073e9SAndroid Build Coastguard Worker ) 3267*da0073e9SAndroid Build Coastguard Worker def numpy_sin_inplace(x: Tensor) -> Tensor: 3268*da0073e9SAndroid Build Coastguard Worker x_np = x.numpy() 3269*da0073e9SAndroid Build Coastguard Worker np.sin(x_np, out=x_np) 3270*da0073e9SAndroid Build Coastguard Worker return x 3271*da0073e9SAndroid Build Coastguard Worker 3272*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 3273*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "may not alias"): 3274*da0073e9SAndroid Build Coastguard Worker numpy_sin_inplace(x) 3275*da0073e9SAndroid Build Coastguard Worker 3276*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 3277*da0073e9SAndroid Build Coastguard Worker def test_factory_function(self): 3278*da0073e9SAndroid Build Coastguard Worker @torch.library.custom_op( 3279*da0073e9SAndroid Build Coastguard Worker "_torch_testing::f", mutates_args={}, device_types="cpu" 3280*da0073e9SAndroid Build Coastguard Worker ) 3281*da0073e9SAndroid Build Coastguard Worker def f(device: torch.device) -> Tensor: 3282*da0073e9SAndroid Build Coastguard Worker return torch.ones(3) 3283*da0073e9SAndroid Build Coastguard Worker 3284*da0073e9SAndroid Build Coastguard Worker result = f(device="cpu") 3285*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.device, torch.device("cpu")) 3286*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, torch.ones(3)) 3287*da0073e9SAndroid Build Coastguard Worker 3288*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 3289*da0073e9SAndroid Build Coastguard Worker RuntimeError, "f does not have a kernel registered for cuda" 3290*da0073e9SAndroid Build Coastguard Worker ): 3291*da0073e9SAndroid Build Coastguard Worker f("cuda") 3292*da0073e9SAndroid Build Coastguard Worker 3293*da0073e9SAndroid Build Coastguard Worker 
with self.assertRaisesRegex( 3294*da0073e9SAndroid Build Coastguard Worker ValueError, 3295*da0073e9SAndroid Build Coastguard Worker "Functions without tensor inputs are required to have a `device: torch.device` argument", 3296*da0073e9SAndroid Build Coastguard Worker ): 3297*da0073e9SAndroid Build Coastguard Worker 3298*da0073e9SAndroid Build Coastguard Worker @torch.library.custom_op( 3299*da0073e9SAndroid Build Coastguard Worker "_torch_testing::f2", mutates_args={}, device_types="cpu" 3300*da0073e9SAndroid Build Coastguard Worker ) 3301*da0073e9SAndroid Build Coastguard Worker def f2() -> Tensor: 3302*da0073e9SAndroid Build Coastguard Worker return torch.ones(3) 3303*da0073e9SAndroid Build Coastguard Worker 3304*da0073e9SAndroid Build Coastguard Worker @torch.library.custom_op("_torch_testing::f3", mutates_args={}) 3305*da0073e9SAndroid Build Coastguard Worker def f3() -> Tensor: 3306*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError("NYI") 3307*da0073e9SAndroid Build Coastguard Worker 3308*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 3309*da0073e9SAndroid Build Coastguard Worker ValueError, 3310*da0073e9SAndroid Build Coastguard Worker "Functions without tensor inputs are required to have a `device: torch.device` argument", 3311*da0073e9SAndroid Build Coastguard Worker ): 3312*da0073e9SAndroid Build Coastguard Worker 3313*da0073e9SAndroid Build Coastguard Worker @f3.register_kernel("cpu") 3314*da0073e9SAndroid Build Coastguard Worker def _(): 3315*da0073e9SAndroid Build Coastguard Worker return torch.zeros(3) 3316*da0073e9SAndroid Build Coastguard Worker 3317*da0073e9SAndroid Build Coastguard Worker result = f(x) 3318*da0073e9SAndroid Build Coastguard Worker 3319*da0073e9SAndroid Build Coastguard Worker @torch.library.custom_op("_torch_testing::f4", mutates_args={}) 3320*da0073e9SAndroid Build Coastguard Worker def f4(device: torch.device) -> Tensor: 3321*da0073e9SAndroid Build Coastguard Worker raise 
NotImplementedError("NYI") 3322*da0073e9SAndroid Build Coastguard Worker 3323*da0073e9SAndroid Build Coastguard Worker @f4.register_kernel("cpu") 3324*da0073e9SAndroid Build Coastguard Worker def _(device: torch.device): 3325*da0073e9SAndroid Build Coastguard Worker return torch.zeros(3) 3326*da0073e9SAndroid Build Coastguard Worker 3327*da0073e9SAndroid Build Coastguard Worker result = f(device="cpu") 3328*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.device, torch.device("cpu")) 3329*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, torch.ones(3)) 3330*da0073e9SAndroid Build Coastguard Worker 3331*da0073e9SAndroid Build Coastguard Worker def test_library_schema_infer(self): 3332*da0073e9SAndroid Build Coastguard Worker def foo_impl(x: torch.Tensor) -> torch.Tensor: 3333*da0073e9SAndroid Build Coastguard Worker return x.sin() 3334*da0073e9SAndroid Build Coastguard Worker 3335*da0073e9SAndroid Build Coastguard Worker schema = torch.library.infer_schema(foo_impl, op_name="myop", mutates_args={}) 3336*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(schema, "myop(Tensor x) -> Tensor") 3337*da0073e9SAndroid Build Coastguard Worker 3338*da0073e9SAndroid Build Coastguard Worker schema = torch.library.infer_schema(foo_impl, mutates_args={}) 3339*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(schema, "(Tensor x) -> Tensor") 3340*da0073e9SAndroid Build Coastguard Worker 3341*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 3342*da0073e9SAndroid Build Coastguard Worker def test_set_kernel_enabled(self): 3343*da0073e9SAndroid Build Coastguard Worker x = torch.ones(1) 3344*da0073e9SAndroid Build Coastguard Worker 3345*da0073e9SAndroid Build Coastguard Worker @torch.library.custom_op("mylib::f", mutates_args=()) 3346*da0073e9SAndroid Build Coastguard Worker def f(x: Tensor) -> Tensor: 3347*da0073e9SAndroid Build Coastguard Worker 
return x + 1 3348*da0073e9SAndroid Build Coastguard Worker 3349*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f(x), x + 1) 3350*da0073e9SAndroid Build Coastguard Worker with self.assertLogs("torch._library.custom_ops") as captured: 3351*da0073e9SAndroid Build Coastguard Worker with f.set_kernel_enabled("gpu", enabled=False): 3352*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f(x), x + 1) 3353*da0073e9SAndroid Build Coastguard Worker self.assertIn( 3354*da0073e9SAndroid Build Coastguard Worker "no kernel was registered for this device type", captured.output[0] 3355*da0073e9SAndroid Build Coastguard Worker ) 3356*da0073e9SAndroid Build Coastguard Worker 3357*da0073e9SAndroid Build Coastguard Worker @f.register_kernel("cpu") 3358*da0073e9SAndroid Build Coastguard Worker def _(x): 3359*da0073e9SAndroid Build Coastguard Worker return x + 2 3360*da0073e9SAndroid Build Coastguard Worker 3361*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f(x), x + 2) 3362*da0073e9SAndroid Build Coastguard Worker 3363*da0073e9SAndroid Build Coastguard Worker with self.assertLogs("torch._library.custom_ops") as captured: 3364*da0073e9SAndroid Build Coastguard Worker with f.set_kernel_enabled("cpu", enabled=True): 3365*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f(x), x + 2) 3366*da0073e9SAndroid Build Coastguard Worker self.assertIn("already enabled", captured.output[0]) 3367*da0073e9SAndroid Build Coastguard Worker 3368*da0073e9SAndroid Build Coastguard Worker with f.set_kernel_enabled("cpu", enabled=False): 3369*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f(x), x + 1) 3370*da0073e9SAndroid Build Coastguard Worker 3371*da0073e9SAndroid Build Coastguard Worker with self.assertLogs("torch._library.custom_ops") as captured: 3372*da0073e9SAndroid Build Coastguard Worker with f.set_kernel_enabled("cpu", enabled=False): 3373*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f(x), x + 1) 3374*da0073e9SAndroid Build Coastguard 
Worker self.assertIn("already disabled", captured.output[0]) 3375*da0073e9SAndroid Build Coastguard Worker 3376*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f(x), x + 1) 3377*da0073e9SAndroid Build Coastguard Worker 3378*da0073e9SAndroid Build Coastguard Worker with f.set_kernel_enabled("cpu", enabled=True): 3379*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f(x), x + 2) 3380*da0073e9SAndroid Build Coastguard Worker 3381*da0073e9SAndroid Build Coastguard Worker with f.set_kernel_enabled("cpu", enabled=False): 3382*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f(x), x + 1) 3383*da0073e9SAndroid Build Coastguard Worker 3384*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 3385*da0073e9SAndroid Build Coastguard Worker def test_register_vmap_kwargonly_low_level(self): 3386*da0073e9SAndroid Build Coastguard Worker with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib: 3387*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor x, *, float y) -> Tensor") 3388*da0073e9SAndroid Build Coastguard Worker called = False 3389*da0073e9SAndroid Build Coastguard Worker 3390*da0073e9SAndroid Build Coastguard Worker def foo_impl(x, *, y): 3391*da0073e9SAndroid Build Coastguard Worker return x * y 3392*da0073e9SAndroid Build Coastguard Worker 3393*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", foo_impl, "CPU") 3394*da0073e9SAndroid Build Coastguard Worker 3395*da0073e9SAndroid Build Coastguard Worker def vmap(info, in_dims, x, *, y): 3396*da0073e9SAndroid Build Coastguard Worker nonlocal called 3397*da0073e9SAndroid Build Coastguard Worker called = True 3398*da0073e9SAndroid Build Coastguard Worker return x * y, 0 3399*da0073e9SAndroid Build Coastguard Worker 3400*da0073e9SAndroid Build Coastguard Worker torch.library.register_vmap("_torch_testing::foo", vmap, lib=lib) 3401*da0073e9SAndroid Build Coastguard Worker 3402*da0073e9SAndroid Build 
Coastguard Worker x = torch.ones(3) 3403*da0073e9SAndroid Build Coastguard Worker result = torch.vmap(torch.ops._torch_testing.foo)(x, y=3.14) 3404*da0073e9SAndroid Build Coastguard Worker self.assertTrue(called) 3405*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, torch.tensor([3.14, 3.14, 3.14])) 3406*da0073e9SAndroid Build Coastguard Worker 3407*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 3408*da0073e9SAndroid Build Coastguard Worker def test_register_vmap_defaults(self): 3409*da0073e9SAndroid Build Coastguard Worker with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib: 3410*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor w, int x = 2, *, int y = 3, int z) -> Tensor") 3411*da0073e9SAndroid Build Coastguard Worker 3412*da0073e9SAndroid Build Coastguard Worker def foo_impl(w, x=2, *, y=3, z): 3413*da0073e9SAndroid Build Coastguard Worker return w * x * y * z 3414*da0073e9SAndroid Build Coastguard Worker 3415*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", foo_impl, "CPU") 3416*da0073e9SAndroid Build Coastguard Worker 3417*da0073e9SAndroid Build Coastguard Worker called = False 3418*da0073e9SAndroid Build Coastguard Worker 3419*da0073e9SAndroid Build Coastguard Worker def vmap(info, in_dims, w, x=2, *, y=3, z): 3420*da0073e9SAndroid Build Coastguard Worker nonlocal called 3421*da0073e9SAndroid Build Coastguard Worker called = True 3422*da0073e9SAndroid Build Coastguard Worker return w * x * y * z, 0 3423*da0073e9SAndroid Build Coastguard Worker 3424*da0073e9SAndroid Build Coastguard Worker torch.library.register_vmap("_torch_testing::foo", vmap, lib=lib) 3425*da0073e9SAndroid Build Coastguard Worker 3426*da0073e9SAndroid Build Coastguard Worker w = torch.ones(3) 3427*da0073e9SAndroid Build Coastguard Worker result = torch.vmap(torch.ops._torch_testing.foo)(w, z=42) 3428*da0073e9SAndroid Build Coastguard Worker 
self.assertTrue(called) 3429*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, w * 2 * 3 * 42) 3430*da0073e9SAndroid Build Coastguard Worker 3431*da0073e9SAndroid Build Coastguard Worker def test_layout_constraint_tags(self): 3432*da0073e9SAndroid Build Coastguard Worker needs_fixed_stride_order = torch._C.Tag.needs_fixed_stride_order 3433*da0073e9SAndroid Build Coastguard Worker flexible_layout = torch._C.Tag.flexible_layout 3434*da0073e9SAndroid Build Coastguard Worker # (tags, the result of the tag inference) 3435*da0073e9SAndroid Build Coastguard Worker tests = [ 3436*da0073e9SAndroid Build Coastguard Worker ({needs_fixed_stride_order}, needs_fixed_stride_order), 3437*da0073e9SAndroid Build Coastguard Worker ({flexible_layout}, flexible_layout), 3438*da0073e9SAndroid Build Coastguard Worker # If no tags are provided, then the following is the default 3439*da0073e9SAndroid Build Coastguard Worker (set(), flexible_layout), 3440*da0073e9SAndroid Build Coastguard Worker # If multiple tags are provided, then we use the most constrained tag. 
3441*da0073e9SAndroid Build Coastguard Worker ({flexible_layout, needs_fixed_stride_order}, needs_fixed_stride_order), 3442*da0073e9SAndroid Build Coastguard Worker ] 3443*da0073e9SAndroid Build Coastguard Worker from torch._inductor.lowering import get_layout_constraint_tag 3444*da0073e9SAndroid Build Coastguard Worker 3445*da0073e9SAndroid Build Coastguard Worker for tags, expected in tests: 3446*da0073e9SAndroid Build Coastguard Worker with torch.library._scoped_library("mylib", "FRAGMENT") as m: 3447*da0073e9SAndroid Build Coastguard Worker m.define("foobar(Tensor x) -> Tensor", tags=tags) 3448*da0073e9SAndroid Build Coastguard Worker result = get_layout_constraint_tag(torch.ops.mylib.foobar.default) 3449*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, expected) 3450*da0073e9SAndroid Build Coastguard Worker 3451*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 3452*da0073e9SAndroid Build Coastguard Worker def test_library_register_vmap(self): 3453*da0073e9SAndroid Build Coastguard Worker for mode in ["function", "qualname", "opoverload", "c_opdef"]: 3454*da0073e9SAndroid Build Coastguard Worker 3455*da0073e9SAndroid Build Coastguard Worker @torch.library.custom_op("mylib::f", mutates_args=()) 3456*da0073e9SAndroid Build Coastguard Worker def f(x: Tensor, y: Tensor) -> Tensor: 3457*da0073e9SAndroid Build Coastguard Worker return x * y 3458*da0073e9SAndroid Build Coastguard Worker 3459*da0073e9SAndroid Build Coastguard Worker called = False 3460*da0073e9SAndroid Build Coastguard Worker 3461*da0073e9SAndroid Build Coastguard Worker def fvmap(info, in_dims, x, y): 3462*da0073e9SAndroid Build Coastguard Worker nonlocal called 3463*da0073e9SAndroid Build Coastguard Worker called = True 3464*da0073e9SAndroid Build Coastguard Worker x_bdim, y_bdim = in_dims 3465*da0073e9SAndroid Build Coastguard Worker x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) 
3466*da0073e9SAndroid Build Coastguard Worker y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) 3467*da0073e9SAndroid Build Coastguard Worker result = x * y 3468*da0073e9SAndroid Build Coastguard Worker result = result.movedim(-1, 0) 3469*da0073e9SAndroid Build Coastguard Worker return result, 0 3470*da0073e9SAndroid Build Coastguard Worker 3471*da0073e9SAndroid Build Coastguard Worker if mode == "function": 3472*da0073e9SAndroid Build Coastguard Worker torch.library.register_vmap(f, fvmap) 3473*da0073e9SAndroid Build Coastguard Worker elif mode == "qualname": 3474*da0073e9SAndroid Build Coastguard Worker torch.library.register_vmap("mylib::f", fvmap) 3475*da0073e9SAndroid Build Coastguard Worker elif mode == "opoverload": 3476*da0073e9SAndroid Build Coastguard Worker torch.library.register_vmap(torch.ops.mylib.f.default, fvmap) 3477*da0073e9SAndroid Build Coastguard Worker elif mode == "c_opdef": 3478*da0073e9SAndroid Build Coastguard Worker f.register_vmap(fvmap) 3479*da0073e9SAndroid Build Coastguard Worker 3480*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 2) 3481*da0073e9SAndroid Build Coastguard Worker y = torch.randn(2, 2) 3482*da0073e9SAndroid Build Coastguard Worker 3483*da0073e9SAndroid Build Coastguard Worker result = torch.vmap(f)(x, y) 3484*da0073e9SAndroid Build Coastguard Worker self.assertTrue(called) 3485*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, x * y) 3486*da0073e9SAndroid Build Coastguard Worker 3487*da0073e9SAndroid Build Coastguard Worker called = False 3488*da0073e9SAndroid Build Coastguard Worker result = torch.vmap(f, out_dims=1)(x, y) 3489*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, (x * y).T) 3490*da0073e9SAndroid Build Coastguard Worker self.assertTrue(called) 3491*da0073e9SAndroid Build Coastguard Worker 3492*da0073e9SAndroid Build Coastguard Worker called = False 3493*da0073e9SAndroid Build Coastguard Worker result = torch.vmap(f, in_dims=1)(x, y) 
3494*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, (x * y).T) 3495*da0073e9SAndroid Build Coastguard Worker self.assertTrue(called) 3496*da0073e9SAndroid Build Coastguard Worker 3497*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 3498*da0073e9SAndroid Build Coastguard Worker def test_library_register_vmap_library_decorator(self): 3499*da0073e9SAndroid Build Coastguard Worker @torch.library.custom_op("mylib::f", mutates_args=()) 3500*da0073e9SAndroid Build Coastguard Worker def f(x: Tensor, y: Tensor) -> Tensor: 3501*da0073e9SAndroid Build Coastguard Worker return x * y 3502*da0073e9SAndroid Build Coastguard Worker 3503*da0073e9SAndroid Build Coastguard Worker called = False 3504*da0073e9SAndroid Build Coastguard Worker 3505*da0073e9SAndroid Build Coastguard Worker @torch.library.register_vmap("mylib::f") 3506*da0073e9SAndroid Build Coastguard Worker def fvmap(info, in_dims, x, y): 3507*da0073e9SAndroid Build Coastguard Worker nonlocal called 3508*da0073e9SAndroid Build Coastguard Worker called = True 3509*da0073e9SAndroid Build Coastguard Worker x_bdim, y_bdim = in_dims 3510*da0073e9SAndroid Build Coastguard Worker x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) 3511*da0073e9SAndroid Build Coastguard Worker y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) 3512*da0073e9SAndroid Build Coastguard Worker result = x * y 3513*da0073e9SAndroid Build Coastguard Worker result = result.movedim(-1, 0) 3514*da0073e9SAndroid Build Coastguard Worker return result, 0 3515*da0073e9SAndroid Build Coastguard Worker 3516*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 2) 3517*da0073e9SAndroid Build Coastguard Worker y = torch.randn(2, 2) 3518*da0073e9SAndroid Build Coastguard Worker 3519*da0073e9SAndroid Build Coastguard Worker result = torch.vmap(f)(x, y) 3520*da0073e9SAndroid Build Coastguard Worker self.assertTrue(called) 
3521*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, x * y) 3522*da0073e9SAndroid Build Coastguard Worker 3523*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 3524*da0073e9SAndroid Build Coastguard Worker def test_library_register_vmap_op_decorator(self): 3525*da0073e9SAndroid Build Coastguard Worker @torch.library.custom_op("mylib::f", mutates_args=()) 3526*da0073e9SAndroid Build Coastguard Worker def f(x: Tensor, y: Tensor) -> Tensor: 3527*da0073e9SAndroid Build Coastguard Worker return x * y 3528*da0073e9SAndroid Build Coastguard Worker 3529*da0073e9SAndroid Build Coastguard Worker called = False 3530*da0073e9SAndroid Build Coastguard Worker 3531*da0073e9SAndroid Build Coastguard Worker @f.register_vmap 3532*da0073e9SAndroid Build Coastguard Worker def fvmap(info, in_dims, x, y): 3533*da0073e9SAndroid Build Coastguard Worker nonlocal called 3534*da0073e9SAndroid Build Coastguard Worker called = True 3535*da0073e9SAndroid Build Coastguard Worker x_bdim, y_bdim = in_dims 3536*da0073e9SAndroid Build Coastguard Worker x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) 3537*da0073e9SAndroid Build Coastguard Worker y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) 3538*da0073e9SAndroid Build Coastguard Worker result = x * y 3539*da0073e9SAndroid Build Coastguard Worker result = result.movedim(-1, 0) 3540*da0073e9SAndroid Build Coastguard Worker return result, 0 3541*da0073e9SAndroid Build Coastguard Worker 3542*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 2) 3543*da0073e9SAndroid Build Coastguard Worker y = torch.randn(2, 2) 3544*da0073e9SAndroid Build Coastguard Worker 3545*da0073e9SAndroid Build Coastguard Worker result = torch.vmap(f)(x, y) 3546*da0073e9SAndroid Build Coastguard Worker self.assertTrue(called) 3547*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, x * y) 3548*da0073e9SAndroid Build Coastguard 
Worker 3549*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 3550*da0073e9SAndroid Build Coastguard Worker def test_library_register_vmap_register_multiple_times(self): 3551*da0073e9SAndroid Build Coastguard Worker @torch.library.custom_op("mylib::f", mutates_args=()) 3552*da0073e9SAndroid Build Coastguard Worker def f(x: Tensor, y: Tensor) -> Tensor: 3553*da0073e9SAndroid Build Coastguard Worker return x * y 3554*da0073e9SAndroid Build Coastguard Worker 3555*da0073e9SAndroid Build Coastguard Worker called = False 3556*da0073e9SAndroid Build Coastguard Worker 3557*da0073e9SAndroid Build Coastguard Worker @f.register_vmap 3558*da0073e9SAndroid Build Coastguard Worker def fvmap(info, in_dims, x, y): 3559*da0073e9SAndroid Build Coastguard Worker nonlocal called 3560*da0073e9SAndroid Build Coastguard Worker called = True 3561*da0073e9SAndroid Build Coastguard Worker x_bdim, y_bdim = in_dims 3562*da0073e9SAndroid Build Coastguard Worker x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) 3563*da0073e9SAndroid Build Coastguard Worker y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) 3564*da0073e9SAndroid Build Coastguard Worker result = x * y 3565*da0073e9SAndroid Build Coastguard Worker result = result.movedim(-1, 0) 3566*da0073e9SAndroid Build Coastguard Worker return result, 0 3567*da0073e9SAndroid Build Coastguard Worker 3568*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 2) 3569*da0073e9SAndroid Build Coastguard Worker y = torch.randn(2, 2) 3570*da0073e9SAndroid Build Coastguard Worker 3571*da0073e9SAndroid Build Coastguard Worker result = torch.vmap(f)(x, y) 3572*da0073e9SAndroid Build Coastguard Worker self.assertTrue(called) 3573*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, x * y) 3574*da0073e9SAndroid Build Coastguard Worker called = False 3575*da0073e9SAndroid Build Coastguard Worker 3576*da0073e9SAndroid Build Coastguard 
Worker @f.register_vmap 3577*da0073e9SAndroid Build Coastguard Worker def fvmap2(info, in_dims, x, y): 3578*da0073e9SAndroid Build Coastguard Worker nonlocal called 3579*da0073e9SAndroid Build Coastguard Worker called = True 3580*da0073e9SAndroid Build Coastguard Worker x_bdim, y_bdim = in_dims 3581*da0073e9SAndroid Build Coastguard Worker x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) 3582*da0073e9SAndroid Build Coastguard Worker y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) 3583*da0073e9SAndroid Build Coastguard Worker result = x + y 3584*da0073e9SAndroid Build Coastguard Worker result = result.movedim(-1, 0) 3585*da0073e9SAndroid Build Coastguard Worker return result, 0 3586*da0073e9SAndroid Build Coastguard Worker 3587*da0073e9SAndroid Build Coastguard Worker result = torch.vmap(f)(x, y) 3588*da0073e9SAndroid Build Coastguard Worker self.assertTrue(called) 3589*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, x + y) 3590*da0073e9SAndroid Build Coastguard Worker 3591*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 3592*da0073e9SAndroid Build Coastguard Worker def test_library_register_vmap_register_multiple_times_2(self): 3593*da0073e9SAndroid Build Coastguard Worker @torch.library.custom_op("mylib::f", mutates_args=()) 3594*da0073e9SAndroid Build Coastguard Worker def f(x: Tensor, y: Tensor) -> Tensor: 3595*da0073e9SAndroid Build Coastguard Worker return x * y 3596*da0073e9SAndroid Build Coastguard Worker 3597*da0073e9SAndroid Build Coastguard Worker called = False 3598*da0073e9SAndroid Build Coastguard Worker 3599*da0073e9SAndroid Build Coastguard Worker @torch.library.register_vmap("mylib::f") 3600*da0073e9SAndroid Build Coastguard Worker def fvmap(info, in_dims, x, y): 3601*da0073e9SAndroid Build Coastguard Worker nonlocal called 3602*da0073e9SAndroid Build Coastguard Worker called = True 3603*da0073e9SAndroid Build Coastguard 
Worker x_bdim, y_bdim = in_dims 3604*da0073e9SAndroid Build Coastguard Worker x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) 3605*da0073e9SAndroid Build Coastguard Worker y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) 3606*da0073e9SAndroid Build Coastguard Worker result = x * y 3607*da0073e9SAndroid Build Coastguard Worker result = result.movedim(-1, 0) 3608*da0073e9SAndroid Build Coastguard Worker return result, 0 3609*da0073e9SAndroid Build Coastguard Worker 3610*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 2) 3611*da0073e9SAndroid Build Coastguard Worker y = torch.randn(2, 2) 3612*da0073e9SAndroid Build Coastguard Worker 3613*da0073e9SAndroid Build Coastguard Worker result = torch.vmap(f)(x, y) 3614*da0073e9SAndroid Build Coastguard Worker self.assertTrue(called) 3615*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, x * y) 3616*da0073e9SAndroid Build Coastguard Worker called = False 3617*da0073e9SAndroid Build Coastguard Worker 3618*da0073e9SAndroid Build Coastguard Worker @torch.library.register_vmap("mylib::f") 3619*da0073e9SAndroid Build Coastguard Worker def fvmap2(info, in_dims, x, y): 3620*da0073e9SAndroid Build Coastguard Worker nonlocal called 3621*da0073e9SAndroid Build Coastguard Worker called = True 3622*da0073e9SAndroid Build Coastguard Worker x_bdim, y_bdim = in_dims 3623*da0073e9SAndroid Build Coastguard Worker x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) 3624*da0073e9SAndroid Build Coastguard Worker y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) 3625*da0073e9SAndroid Build Coastguard Worker result = x + y 3626*da0073e9SAndroid Build Coastguard Worker result = result.movedim(-1, 0) 3627*da0073e9SAndroid Build Coastguard Worker return result, 0 3628*da0073e9SAndroid Build Coastguard Worker 3629*da0073e9SAndroid Build Coastguard Worker result = torch.vmap(f)(x, y) 3630*da0073e9SAndroid Build Coastguard Worker 
class MiniOpTestOther(CustomOpTestCaseBase):
    # Small test class that is passed to optests.generate_opcheck_tests
    # further down in this file.
    test_ns = "mini_op_test"

    def test_nonzero_again(self):
        # aten.nonzero on a 1-D tensor: one row per nonzero position.
        values = torch.tensor([0, 1, 2, 0, 0])
        positions = torch.ops.aten.nonzero.default(values)
        self.assertEqual(positions, torch.tensor([[1], [2]]))
test_utils=optests.generate_tests.DEPRECATED_DEFAULT_TEST_UTILS, 3658*da0073e9SAndroid Build Coastguard Worker) 3659*da0073e9SAndroid Build Coastguard Worker 3660*da0073e9SAndroid Build Coastguard Worker 3661*da0073e9SAndroid Build Coastguard Workerclass TestGenerateOpcheckTests(CustomOpTestCaseBase): 3662*da0073e9SAndroid Build Coastguard Worker def test_MiniOpTest(self): 3663*da0073e9SAndroid Build Coastguard Worker for orig_test in ["test_mm", "test_nonzero"]: 3664*da0073e9SAndroid Build Coastguard Worker for ( 3665*da0073e9SAndroid Build Coastguard Worker test 3666*da0073e9SAndroid Build Coastguard Worker ) in torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS: 3667*da0073e9SAndroid Build Coastguard Worker expected_test = f"{test}__{orig_test}" 3668*da0073e9SAndroid Build Coastguard Worker self.assertTrue(hasattr(MiniOpTest, expected_test), msg=expected_test) 3669*da0073e9SAndroid Build Coastguard Worker 3670*da0073e9SAndroid Build Coastguard Worker def test_generate_repro_save_data(self): 3671*da0073e9SAndroid Build Coastguard Worker from torch.testing._internal.optests.generate_tests import generate_repro 3672*da0073e9SAndroid Build Coastguard Worker 3673*da0073e9SAndroid Build Coastguard Worker args = (torch.ones(2, 2),) 3674*da0073e9SAndroid Build Coastguard Worker kwargs = {"mat2": torch.zeros(2, 2)} 3675*da0073e9SAndroid Build Coastguard Worker actual = generate_repro( 3676*da0073e9SAndroid Build Coastguard Worker "test_schema", 3677*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.sin.default, 3678*da0073e9SAndroid Build Coastguard Worker args, 3679*da0073e9SAndroid Build Coastguard Worker kwargs, 3680*da0073e9SAndroid Build Coastguard Worker save_data=True, 3681*da0073e9SAndroid Build Coastguard Worker dry_run=True, 3682*da0073e9SAndroid Build Coastguard Worker ) 3683*da0073e9SAndroid Build Coastguard Worker actual = re.sub(r"torch.load\(\".*\.pt\"\)", 'torch.load("repro.pt")', actual) 3684*da0073e9SAndroid Build Coastguard Worker 
self.assertExpectedInline( 3685*da0073e9SAndroid Build Coastguard Worker actual, 3686*da0073e9SAndroid Build Coastguard Worker """\ 3687*da0073e9SAndroid Build Coastguard Worker# ========================================================= 3688*da0073e9SAndroid Build Coastguard Worker# BEGIN REPRO SCRIPT 3689*da0073e9SAndroid Build Coastguard Worker# ========================================================= 3690*da0073e9SAndroid Build Coastguard Workerimport torch 3691*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.optests import opcheck 3692*da0073e9SAndroid Build Coastguard Worker 3693*da0073e9SAndroid Build Coastguard Worker# Make sure you have loaded the library that contains the op 3694*da0073e9SAndroid Build Coastguard Worker# via an import or torch.ops.load_library(...) 3695*da0073e9SAndroid Build Coastguard Workerop = torch.ops.aten.sin.default 3696*da0073e9SAndroid Build Coastguard Worker 3697*da0073e9SAndroid Build Coastguard Workerargs, kwargs = torch.load("repro.pt") 3698*da0073e9SAndroid Build Coastguard Workeropcheck(op, args, kwargs, test_utils="test_schema") 3699*da0073e9SAndroid Build Coastguard Worker# ========================================================= 3700*da0073e9SAndroid Build Coastguard Worker# END REPRO SCRIPT 3701*da0073e9SAndroid Build Coastguard Worker# ========================================================= 3702*da0073e9SAndroid Build Coastguard Worker""", 3703*da0073e9SAndroid Build Coastguard Worker ) 3704*da0073e9SAndroid Build Coastguard Worker 3705*da0073e9SAndroid Build Coastguard Worker def test_generate_repro_no_save_data(self): 3706*da0073e9SAndroid Build Coastguard Worker from torch.testing._internal.optests.generate_tests import generate_repro 3707*da0073e9SAndroid Build Coastguard Worker 3708*da0073e9SAndroid Build Coastguard Worker args = (torch.ones(2, 2),) 3709*da0073e9SAndroid Build Coastguard Worker kwargs = {"mat2": torch.zeros(2, 2)} 3710*da0073e9SAndroid Build Coastguard Worker actual = 
generate_repro( 3711*da0073e9SAndroid Build Coastguard Worker "test_schema", 3712*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.sin.default, 3713*da0073e9SAndroid Build Coastguard Worker args, 3714*da0073e9SAndroid Build Coastguard Worker kwargs, 3715*da0073e9SAndroid Build Coastguard Worker save_data=False, 3716*da0073e9SAndroid Build Coastguard Worker dry_run=True, 3717*da0073e9SAndroid Build Coastguard Worker ) 3718*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 3719*da0073e9SAndroid Build Coastguard Worker actual, 3720*da0073e9SAndroid Build Coastguard Worker """\ 3721*da0073e9SAndroid Build Coastguard Worker# ========================================================= 3722*da0073e9SAndroid Build Coastguard Worker# BEGIN REPRO SCRIPT 3723*da0073e9SAndroid Build Coastguard Worker# ========================================================= 3724*da0073e9SAndroid Build Coastguard Workerimport torch 3725*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.optests import opcheck 3726*da0073e9SAndroid Build Coastguard Worker 3727*da0073e9SAndroid Build Coastguard Worker# Make sure you have loaded the library that contains the op 3728*da0073e9SAndroid Build Coastguard Worker# via an import or torch.ops.load_library(...) 
3729*da0073e9SAndroid Build Coastguard Workerop = torch.ops.aten.sin.default 3730*da0073e9SAndroid Build Coastguard Worker 3731*da0073e9SAndroid Build Coastguard Worker# If you rerun your test with PYTORCH_OPCHECK_PRINT_BETTER_REPRO=1 3732*da0073e9SAndroid Build Coastguard Worker# we will fill them in same (args, kwargs) as in your test 3733*da0073e9SAndroid Build Coastguard Workerargs = () # args to the operator 3734*da0073e9SAndroid Build Coastguard Workerkwargs = {} # kwargs to the operator 3735*da0073e9SAndroid Build Coastguard Workeropcheck(op, args, kwargs, test_utils="test_schema") 3736*da0073e9SAndroid Build Coastguard Worker# ========================================================= 3737*da0073e9SAndroid Build Coastguard Worker# END REPRO SCRIPT 3738*da0073e9SAndroid Build Coastguard Worker# ========================================================= 3739*da0073e9SAndroid Build Coastguard Worker""", 3740*da0073e9SAndroid Build Coastguard Worker ) 3741*da0073e9SAndroid Build Coastguard Worker 3742*da0073e9SAndroid Build Coastguard Worker def test_failures_dict_validation(self): 3743*da0073e9SAndroid Build Coastguard Worker from torch.testing._internal.optests.generate_tests import ( 3744*da0073e9SAndroid Build Coastguard Worker FailuresDict, 3745*da0073e9SAndroid Build Coastguard Worker validate_failures_dict_structure, 3746*da0073e9SAndroid Build Coastguard Worker ) 3747*da0073e9SAndroid Build Coastguard Worker 3748*da0073e9SAndroid Build Coastguard Worker failures = { 3749*da0073e9SAndroid Build Coastguard Worker "mini_op_test::incorrect_schema": { 3750*da0073e9SAndroid Build Coastguard Worker "MiniOpTest.test_aot_dispatch_dynamic__test_delayed_error": { 3751*da0073e9SAndroid Build Coastguard Worker "comment": "", 3752*da0073e9SAndroid Build Coastguard Worker "status": "success", 3753*da0073e9SAndroid Build Coastguard Worker } 3754*da0073e9SAndroid Build Coastguard Worker } 3755*da0073e9SAndroid Build Coastguard Worker } 3756*da0073e9SAndroid Build 
Coastguard Worker with self.assertRaisesRegex(RuntimeError, "got status=success"): 3757*da0073e9SAndroid Build Coastguard Worker validate_failures_dict_structure( 3758*da0073e9SAndroid Build Coastguard Worker FailuresDict("", failures), 3759*da0073e9SAndroid Build Coastguard Worker torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS, 3760*da0073e9SAndroid Build Coastguard Worker MiniOpTest, 3761*da0073e9SAndroid Build Coastguard Worker ) 3762*da0073e9SAndroid Build Coastguard Worker 3763*da0073e9SAndroid Build Coastguard Worker failures = { 3764*da0073e9SAndroid Build Coastguard Worker "mini_op_test::incorrect_schema": { 3765*da0073e9SAndroid Build Coastguard Worker "MiniOpTest.test_aot_dispatch__test_delayed_error": { 3766*da0073e9SAndroid Build Coastguard Worker "comment": "", 3767*da0073e9SAndroid Build Coastguard Worker "status": "xfail", 3768*da0073e9SAndroid Build Coastguard Worker }, 3769*da0073e9SAndroid Build Coastguard Worker } 3770*da0073e9SAndroid Build Coastguard Worker } 3771*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "should begin with one of"): 3772*da0073e9SAndroid Build Coastguard Worker validate_failures_dict_structure( 3773*da0073e9SAndroid Build Coastguard Worker FailuresDict("", failures), 3774*da0073e9SAndroid Build Coastguard Worker torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS, 3775*da0073e9SAndroid Build Coastguard Worker MiniOpTest, 3776*da0073e9SAndroid Build Coastguard Worker ) 3777*da0073e9SAndroid Build Coastguard Worker 3778*da0073e9SAndroid Build Coastguard Worker failures = { 3779*da0073e9SAndroid Build Coastguard Worker "mini_op_test::incorrect_schema": { 3780*da0073e9SAndroid Build Coastguard Worker "MiniOpTest.test_aot_dispatch_dynamic__test_delayed_error_nopenopenope": { 3781*da0073e9SAndroid Build Coastguard Worker "comment": "", 3782*da0073e9SAndroid Build Coastguard Worker "status": "xfail", 3783*da0073e9SAndroid Build Coastguard Worker }, 
3784*da0073e9SAndroid Build Coastguard Worker } 3785*da0073e9SAndroid Build Coastguard Worker } 3786*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "does not exist on the TestCase"): 3787*da0073e9SAndroid Build Coastguard Worker validate_failures_dict_structure( 3788*da0073e9SAndroid Build Coastguard Worker FailuresDict("", failures), 3789*da0073e9SAndroid Build Coastguard Worker torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS, 3790*da0073e9SAndroid Build Coastguard Worker MiniOpTest, 3791*da0073e9SAndroid Build Coastguard Worker ) 3792*da0073e9SAndroid Build Coastguard Worker 3793*da0073e9SAndroid Build Coastguard Worker def test_dont_generate_decorator(self): 3794*da0073e9SAndroid Build Coastguard Worker self.assertTrue(hasattr(MiniOpTest, "test_dont_generate")) 3795*da0073e9SAndroid Build Coastguard Worker self.assertFalse(hasattr(MiniOpTest, "test_schema__test_dont_generate")) 3796*da0073e9SAndroid Build Coastguard Worker 3797*da0073e9SAndroid Build Coastguard Worker def test_opcheck(self): 3798*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, requires_grad=True) 3799*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "OpOverload"): 3800*da0073e9SAndroid Build Coastguard Worker torch.library.opcheck(torch.sin, (x,)) 3801*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "test_utils to be subset of"): 3802*da0073e9SAndroid Build Coastguard Worker torch.library.opcheck(torch.ops.aten.sin.default, (x,), test_utils="blah") 3803*da0073e9SAndroid Build Coastguard Worker result = torch.library.opcheck(torch.ops.aten.sin.default, (x,)) 3804*da0073e9SAndroid Build Coastguard Worker 3805*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 3806*da0073e9SAndroid Build Coastguard Worker result, 3807*da0073e9SAndroid Build Coastguard Worker { 3808*da0073e9SAndroid Build Coastguard Worker "test_schema": "SUCCESS", 3809*da0073e9SAndroid Build 
Coastguard Worker "test_autograd_registration": "SUCCESS", 3810*da0073e9SAndroid Build Coastguard Worker "test_faketensor": "SUCCESS", 3811*da0073e9SAndroid Build Coastguard Worker "test_aot_dispatch_dynamic": "SUCCESS", 3812*da0073e9SAndroid Build Coastguard Worker }, 3813*da0073e9SAndroid Build Coastguard Worker ) 3814*da0073e9SAndroid Build Coastguard Worker 3815*da0073e9SAndroid Build Coastguard Worker result = torch.library.opcheck( 3816*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.sin.default, (x,), test_utils="test_schema" 3817*da0073e9SAndroid Build Coastguard Worker ) 3818*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, {"test_schema": "SUCCESS"}) 3819*da0073e9SAndroid Build Coastguard Worker 3820*da0073e9SAndroid Build Coastguard Worker result = torch.library.opcheck( 3821*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.sin.default, 3822*da0073e9SAndroid Build Coastguard Worker (x,), 3823*da0073e9SAndroid Build Coastguard Worker test_utils=["test_schema", "test_faketensor"], 3824*da0073e9SAndroid Build Coastguard Worker ) 3825*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 3826*da0073e9SAndroid Build Coastguard Worker result, 3827*da0073e9SAndroid Build Coastguard Worker { 3828*da0073e9SAndroid Build Coastguard Worker "test_schema": "SUCCESS", 3829*da0073e9SAndroid Build Coastguard Worker "test_faketensor": "SUCCESS", 3830*da0073e9SAndroid Build Coastguard Worker }, 3831*da0073e9SAndroid Build Coastguard Worker ) 3832*da0073e9SAndroid Build Coastguard Worker 3833*da0073e9SAndroid Build Coastguard Worker def test_opcheck_customopdef(self): 3834*da0073e9SAndroid Build Coastguard Worker sample_inputs = [ 3835*da0073e9SAndroid Build Coastguard Worker (torch.randn(3),), 3836*da0073e9SAndroid Build Coastguard Worker (torch.randn(3, requires_grad=True),), 3837*da0073e9SAndroid Build Coastguard Worker ] 3838*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 3839*da0073e9SAndroid Build 
Coastguard Worker sample_inputs.extend( 3840*da0073e9SAndroid Build Coastguard Worker [ 3841*da0073e9SAndroid Build Coastguard Worker (torch.randn(3, device="cuda"),), 3842*da0073e9SAndroid Build Coastguard Worker (torch.randn(3, device="cuda", requires_grad=True),), 3843*da0073e9SAndroid Build Coastguard Worker ] 3844*da0073e9SAndroid Build Coastguard Worker ) 3845*da0073e9SAndroid Build Coastguard Worker for args in sample_inputs: 3846*da0073e9SAndroid Build Coastguard Worker torch.library.opcheck(custom_op_db.numpy_cube, args) 3847*da0073e9SAndroid Build Coastguard Worker 3848*da0073e9SAndroid Build Coastguard Worker def test_is_inside_opcheck_mode(self): 3849*da0073e9SAndroid Build Coastguard Worker self.assertFalse(optests.is_inside_opcheck_mode()) 3850*da0073e9SAndroid Build Coastguard Worker with optests.generate_tests.OpCheckMode( 3851*da0073e9SAndroid Build Coastguard Worker ["foo"], "bar", lambda x: x, None, "baz", "brr" 3852*da0073e9SAndroid Build Coastguard Worker ): 3853*da0073e9SAndroid Build Coastguard Worker self.assertTrue(optests.is_inside_opcheck_mode()) 3854*da0073e9SAndroid Build Coastguard Worker 3855*da0073e9SAndroid Build Coastguard Worker def test_opcheck_bad_op(self): 3856*da0073e9SAndroid Build Coastguard Worker op = op_with_incorrect_schema(self, "foo") 3857*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 3858*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(Exception, "is not defined to alias output"): 3859*da0073e9SAndroid Build Coastguard Worker torch.library.opcheck(op, (x,)) 3860*da0073e9SAndroid Build Coastguard Worker 3861*da0073e9SAndroid Build Coastguard Worker result = torch.library.opcheck(op, (x,), raise_exception=False) 3862*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(result["test_schema"], RuntimeError)) 3863*da0073e9SAndroid Build Coastguard Worker del result["test_schema"] 3864*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 3865*da0073e9SAndroid Build 
Coastguard Worker result, 3866*da0073e9SAndroid Build Coastguard Worker { 3867*da0073e9SAndroid Build Coastguard Worker "test_autograd_registration": "SUCCESS", 3868*da0073e9SAndroid Build Coastguard Worker "test_faketensor": "SUCCESS", 3869*da0073e9SAndroid Build Coastguard Worker "test_aot_dispatch_dynamic": "SUCCESS", 3870*da0073e9SAndroid Build Coastguard Worker }, 3871*da0073e9SAndroid Build Coastguard Worker ) 3872*da0073e9SAndroid Build Coastguard Worker 3873*da0073e9SAndroid Build Coastguard Worker def test_opcheck_does_not_require_extra_deps(self): 3874*da0073e9SAndroid Build Coastguard Worker # torch.testing._internal.common_utils comes with a lot of additional 3875*da0073e9SAndroid Build Coastguard Worker # test-time dependencies. Since opcheck is public API, it should be 3876*da0073e9SAndroid Build Coastguard Worker # usable only with pytorch install-time dependencies. 3877*da0073e9SAndroid Build Coastguard Worker cmd = [ 3878*da0073e9SAndroid Build Coastguard Worker sys.executable, 3879*da0073e9SAndroid Build Coastguard Worker "-c", 3880*da0073e9SAndroid Build Coastguard Worker "import torch; import sys; \ 3881*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, requires_grad=True); \ 3882*da0073e9SAndroid Build Coastguard Worker torch.library.opcheck(torch.ops.aten.sin.default, (x,)); \ 3883*da0073e9SAndroid Build Coastguard Worker assert 'expecttest' not in sys.modules; \ 3884*da0073e9SAndroid Build Coastguard Worker assert 'torch.testing._internal.common_utils' not in sys.modules", 3885*da0073e9SAndroid Build Coastguard Worker ] 3886*da0073e9SAndroid Build Coastguard Worker subprocess.check_output(cmd, shell=False) 3887*da0073e9SAndroid Build Coastguard Worker 3888*da0073e9SAndroid Build Coastguard Worker 3889*da0073e9SAndroid Build Coastguard Workerclass TestTypeConversion(TestCase): 3890*da0073e9SAndroid Build Coastguard Worker """In infer_schema(), we try to suggest a correct type when the type annotation is wrong.""" 
3891*da0073e9SAndroid Build Coastguard Worker 3892*da0073e9SAndroid Build Coastguard Worker def setUp(self): 3893*da0073e9SAndroid Build Coastguard Worker self.supported_base_types = [ 3894*da0073e9SAndroid Build Coastguard Worker int, 3895*da0073e9SAndroid Build Coastguard Worker float, 3896*da0073e9SAndroid Build Coastguard Worker bool, 3897*da0073e9SAndroid Build Coastguard Worker str, 3898*da0073e9SAndroid Build Coastguard Worker torch.device, 3899*da0073e9SAndroid Build Coastguard Worker torch.Tensor, 3900*da0073e9SAndroid Build Coastguard Worker torch.dtype, 3901*da0073e9SAndroid Build Coastguard Worker torch.types.Number, 3902*da0073e9SAndroid Build Coastguard Worker ] 3903*da0073e9SAndroid Build Coastguard Worker 3904*da0073e9SAndroid Build Coastguard Worker def test_simple_tuple(self): 3905*da0073e9SAndroid Build Coastguard Worker self.assertEqual(List, tuple_to_list(Tuple)) 3906*da0073e9SAndroid Build Coastguard Worker 3907*da0073e9SAndroid Build Coastguard Worker def test_supported_types(self): 3908*da0073e9SAndroid Build Coastguard Worker for t in self.supported_base_types: 3909*da0073e9SAndroid Build Coastguard Worker result_type = tuple_to_list(Tuple[t, t, t]) 3910*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result_type, List[t]) 3911*da0073e9SAndroid Build Coastguard Worker 3912*da0073e9SAndroid Build Coastguard Worker result_type = tuple_to_list(Tuple[t]) 3913*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result_type, List[t]) 3914*da0073e9SAndroid Build Coastguard Worker 3915*da0073e9SAndroid Build Coastguard Worker def test_optional(self): 3916*da0073e9SAndroid Build Coastguard Worker for t in self.supported_base_types: 3917*da0073e9SAndroid Build Coastguard Worker result_type = tuple_to_list(Tuple[t, Optional[t]]) 3918*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result_type, List[Optional[t]]) 3919*da0073e9SAndroid Build Coastguard Worker 3920*da0073e9SAndroid Build Coastguard Worker result_type = 
tuple_to_list(Tuple[t, t, Optional[t]]) 3921*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result_type, List[Optional[t]]) 3922*da0073e9SAndroid Build Coastguard Worker 3923*da0073e9SAndroid Build Coastguard Worker result_type = tuple_to_list(Tuple[t, ...]) 3924*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result_type, List[t]) 3925*da0073e9SAndroid Build Coastguard Worker 3926*da0073e9SAndroid Build Coastguard Worker def test_mixed_types(self): 3927*da0073e9SAndroid Build Coastguard Worker result_type = tuple_to_list(Tuple[int, float]) 3928*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result_type, List[typing.Union[int, float]]) 3929*da0073e9SAndroid Build Coastguard Worker 3930*da0073e9SAndroid Build Coastguard Worker result_type = tuple_to_list(Tuple[int, float, str]) 3931*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result_type, List[typing.Union[int, float, str]]) 3932*da0073e9SAndroid Build Coastguard Worker 3933*da0073e9SAndroid Build Coastguard Worker 3934*da0073e9SAndroid Build Coastguard Workeronly_for = ("cpu", "cuda") 3935*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestCustomOpTesting, globals(), only_for=only_for) 3936*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(TestCustomOp) 3937*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(TestCustomOpAPI) 3938*da0073e9SAndroid Build Coastguard Worker 3939*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 3940*da0073e9SAndroid Build Coastguard Worker run_tests() 3941