1from __future__ import annotations 2 3import dataclasses 4import typing 5import unittest 6from collections import defaultdict 7 8import yaml 9from tools.autograd import gen_autograd_functions, load_derivatives 10 11from torchgen import dest 12from torchgen.api.types import CppSignatureGroup, DispatcherSignature 13from torchgen.context import native_function_manager 14from torchgen.gen import ( 15 get_native_function_declarations, 16 get_native_function_schema_registrations, 17 LineLoader, 18 static_dispatch, 19) 20from torchgen.model import ( 21 BackendIndex, 22 BackendMetadata, 23 DispatchKey, 24 FunctionSchema, 25 Location, 26 NativeFunction, 27 OperatorName, 28) 29from torchgen.native_function_generation import add_generated_native_functions 30from torchgen.selective_build.selector import SelectiveBuilder 31 32 33class TestCreateDerivative(unittest.TestCase): 34 def test_named_grads(self) -> None: 35 schema = FunctionSchema.parse( 36 "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)" 37 ) 38 native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema) 39 40 derivative = load_derivatives.create_derivative( 41 native_function, 42 formula="func_backward(grad_x, grad_y)", 43 var_names=(), 44 available_named_gradients=["grad_x", "grad_y"], 45 ) 46 self.assertSetEqual(derivative.named_gradients, {"grad_x", "grad_y"}) 47 48 def test_non_differentiable_output(self) -> None: 49 specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)" 50 schema = FunctionSchema.parse(specification) 51 native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema) 52 53 _, differentiability_info = load_derivatives.create_differentiability_info( 54 defn_dict={ 55 "name": specification, 56 "dispatch": {"Default": {"a": "grads[0]", "b": "grads[2]"}}, 57 }, 58 functions_by_signature={schema.signature(): [native_function]}, 59 functions_by_schema={specification: native_function}, 60 op_counter=typing.Counter[str](), 61 used_dispatch_keys=set(), 62 ) 63 64 self.assertSequenceEqual( 65 differentiability_info["Default"].available_named_gradients, 66 # grad_y is not present because y is a 67 # bool and thus not differentiable. 68 ["grad_x", "grad_z"], 69 ) 70 71 def test_indexed_grads(self) -> None: 72 schema = FunctionSchema.parse( 73 "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)" 74 ) 75 native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema) 76 77 derivative = load_derivatives.create_derivative( 78 native_function, 79 formula="func_backward(grads[0], grads[1])", 80 var_names=(), 81 available_named_gradients=["grad_x", "grad_y"], 82 ) 83 self.assertSetEqual(derivative.named_gradients, set()) 84 85 def test_named_grads_and_indexed_grads(self) -> None: 86 specification = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)" 87 schema = FunctionSchema.parse(specification) 88 native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema) 89 90 with self.assertRaisesRegex( 91 RuntimeError, 'illegally mixes use of "grad_RETURN_NAME"' 92 ): 93 load_derivatives.create_differentiability_info( 94 defn_dict={ 95 "name": specification, 96 # Uh-oh, the derivatives reference gradients by 97 # name and by index. 98 "dispatch": { 99 "Default": { 100 "a": "grad_x", 101 "b": "grads[1]", 102 } 103 }, 104 }, 105 functions_by_signature={schema.signature(): [native_function]}, 106 functions_by_schema={specification: native_function}, 107 op_counter=typing.Counter[str](), 108 used_dispatch_keys=set(), 109 ) 110 111 112class TestGenAutogradFunctions(unittest.TestCase): 113 def test_non_differentiable_output_invalid_type(self) -> None: 114 specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)" 115 schema = FunctionSchema.parse(specification) 116 native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema) 117 118 _, differentiability_info = load_derivatives.create_differentiability_info( 119 defn_dict={ 120 "name": specification, 121 "dispatch": { 122 "Default": { 123 "a": "grad_x", 124 "b": "grad_z", 125 } 126 }, 127 }, 128 functions_by_signature={schema.signature(): [native_function]}, 129 functions_by_schema={specification: native_function}, 130 op_counter=typing.Counter[str](), 131 used_dispatch_keys=set(), 132 ) 133 definition = gen_autograd_functions.process_function( 134 differentiability_info["Default"], 135 gen_autograd_functions.FUNCTION_DEFINITION, 136 ) 137 # grad_z should map to grads[1], not grads[2] because output 1 138 # (y) is not differentiable. 139 assert "grad_z = grads[2]" not in definition 140 assert "grad_z = grads[1]" in definition 141 142 def test_non_differentiable_output_output_differentiability(self) -> None: 143 specification = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y, Tensor z)" 144 schema = FunctionSchema.parse(specification) 145 native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema) 146 147 _, differentiability_info = load_derivatives.create_differentiability_info( 148 defn_dict={ 149 "name": specification, 150 "dispatch": { 151 "Default": { 152 "a": "grad_x", 153 "b": "grad_z", 154 }, 155 "AutogradNestedTensor": { 156 "a": "grad_z", 157 "b": "grad_x", 158 }, 159 }, 160 "output_differentiability": [True, False, True], 161 }, 162 functions_by_signature={schema.signature(): [native_function]}, 163 functions_by_schema={specification: native_function}, 164 op_counter=typing.Counter[str](), 165 used_dispatch_keys=set(), 166 ) 167 default_definition = gen_autograd_functions.process_function( 168 differentiability_info["Default"], 169 gen_autograd_functions.FUNCTION_DEFINITION, 170 ) 171 # grad_z should map to grads[1], not grads[2] because output 1 172 # (y) is not differentiable. 173 assert "grad_z = grads[2]" not in default_definition 174 assert "grad_z = grads[1]" in default_definition 175 176 nested_tensor_definition = gen_autograd_functions.process_function( 177 differentiability_info["AutogradNestedTensor"], 178 gen_autograd_functions.FUNCTION_DEFINITION, 179 ) 180 assert "grad_z = grads[2]" not in nested_tensor_definition 181 assert "grad_z = grads[1]" in nested_tensor_definition 182 183 def test_register_bogus_dispatch_key(self) -> None: 184 specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)" 185 schema = FunctionSchema.parse(specification) 186 native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema) 187 188 with self.assertRaisesRegex( 189 RuntimeError, 190 "Invalid dispatch key AutogradRandomTensor in derivatives.yaml for", 191 ): 192 load_derivatives.create_differentiability_info( 193 defn_dict={ 194 "name": specification, 195 "dispatch": { 196 "Default": { 197 "a": "grad_x", 198 "b": "grad_z", 199 }, 200 "AutogradRandomTensor": { 201 "a": "grad_x", 202 "b": "grad_z", 203 }, 204 }, 205 }, 206 functions_by_signature={schema.signature(): [native_function]}, 207 functions_by_schema={specification: native_function}, 208 op_counter=typing.Counter[str](), 209 used_dispatch_keys=set(), 210 ) 211 212 213class TestGenSchemaRegistration(unittest.TestCase): 214 def setUp(self) -> None: 215 self.selector = SelectiveBuilder.get_nop_selector() 216 self.custom_native_function, _ = NativeFunction.from_yaml( 217 {"func": "custom::func() -> bool"}, 218 loc=Location(__file__, 1), 219 valid_tags=set(), 220 ) 221 ( 222 self.fragment_custom_native_function, 223 _, 224 ) = NativeFunction.from_yaml( 225 {"func": "quantized_decomposed::func() -> bool"}, 226 loc=Location(__file__, 1), 227 valid_tags=set(), 228 ) 229 230 def test_default_namespace_schema_registration_code_valid(self) -> None: 231 native_functions = [DEFAULT_NATIVE_FUNCTION] 232 registrations, _ = get_native_function_schema_registrations( 233 native_functions=native_functions, 234 schema_selector=self.selector, 235 ) 236 self.assertEqual(registrations, ['m.def("func() -> bool", {});\n']) 237 238 def test_custom_namespace_schema_registration_code_valid(self) -> None: 239 _, registrations = get_native_function_schema_registrations( 240 native_functions=[self.custom_native_function], 241 schema_selector=self.selector, 242 ) 243 self.assertEqual( 244 registrations, 245 """ 246TORCH_LIBRARY(custom, m) { 247 m.def("func() -> bool", {}); 248 249};""", 250 ) 251 252 def test_fragment_custom_namespace_schema_registration_code_valid(self) -> None: 253 """Sometimes we want to extend an existing namespace, for example quantized 254 namespace, which is already defined in native/quantized/library.cpp 255 """ 256 _, registrations = get_native_function_schema_registrations( 257 native_functions=[self.fragment_custom_native_function], 258 schema_selector=self.selector, 259 ) 260 self.assertEqual( 261 registrations, 262 """ 263TORCH_LIBRARY_FRAGMENT(quantized_decomposed, m) { 264 m.def("func() -> bool", {}); 265 266};""", 267 ) 268 269 def test_mixed_namespace_schema_registration_code_valid(self) -> None: 270 ( 271 aten_registrations, 272 custom_registrations, 273 ) = get_native_function_schema_registrations( 274 native_functions=[DEFAULT_NATIVE_FUNCTION, self.custom_native_function], 275 schema_selector=self.selector, 276 ) 277 self.assertEqual(aten_registrations, ['m.def("func() -> bool", {});\n']) 278 self.assertEqual( 279 custom_registrations, 280 """ 281TORCH_LIBRARY(custom, m) { 282 m.def("func() -> bool", {}); 283 284};""", 285 ) 286 287 def test_3_namespaces_schema_registration_code_valid(self) -> None: 288 custom2_native_function, _ = NativeFunction.from_yaml( 289 {"func": "custom2::func() -> bool"}, 290 loc=Location(__file__, 1), 291 valid_tags=set(), 292 ) 293 ( 294 aten_registrations, 295 custom_registrations, 296 ) = get_native_function_schema_registrations( 297 native_functions=[ 298 DEFAULT_NATIVE_FUNCTION, 299 self.custom_native_function, 300 custom2_native_function, 301 ], 302 schema_selector=self.selector, 303 ) 304 self.assertEqual(aten_registrations, ['m.def("func() -> bool", {});\n']) 305 self.assertEqual( 306 custom_registrations, 307 """ 308TORCH_LIBRARY(custom, m) { 309 m.def("func() -> bool", {}); 310 311}; 312TORCH_LIBRARY(custom2, m) { 313 m.def("func() -> bool", {}); 314 315};""", 316 ) 317 318 319class TestGenNativeFunctionDeclaration(unittest.TestCase): 320 def setUp(self) -> None: 321 self.op_1_native_function, op_1_backend_index = NativeFunction.from_yaml( 322 {"func": "op_1() -> bool", "dispatch": {"CPU": "kernel_1"}}, 323 loc=Location(__file__, 1), 324 valid_tags=set(), 325 ) 326 self.op_2_native_function, op_2_backend_index = NativeFunction.from_yaml( 327 { 328 "func": "op_2() -> bool", 329 "dispatch": {"CPU": "kernel_2", "QuantizedCPU": "custom::kernel_3"}, 330 }, 331 loc=Location(__file__, 1), 332 valid_tags=set(), 333 ) 334 335 backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = { 336 DispatchKey.CPU: {}, 337 DispatchKey.QuantizedCPU: {}, 338 } 339 BackendIndex.grow_index(backend_indices, op_1_backend_index) 340 BackendIndex.grow_index(backend_indices, op_2_backend_index) 341 self.backend_indices = { 342 k: BackendIndex( 343 dispatch_key=k, 344 use_out_as_primary=True, 345 external=False, 346 device_guard=False, 347 index=backend_indices[k], 348 ) 349 for k in backend_indices 350 } 351 352 def test_native_function_declaration_1_op_2_ns_error(self) -> None: 353 with self.assertRaises(AssertionError): 354 get_native_function_declarations( 355 grouped_native_functions=[ 356 self.op_1_native_function, 357 self.op_2_native_function, 358 ], 359 backend_indices=self.backend_indices, 360 native_function_decl_gen=dest.compute_native_function_declaration, 361 ) 362 363 def test_native_function_declaration_1_op_1_ns_valid(self) -> None: 364 self.assertIsInstance(self.op_1_native_function, NativeFunction) 365 declaration = get_native_function_declarations( 366 grouped_native_functions=[ 367 self.op_1_native_function, 368 ], 369 backend_indices=self.backend_indices, 370 native_function_decl_gen=dest.compute_native_function_declaration, 371 ) 372 target = """ 373namespace at { 374namespace native { 375TORCH_API bool kernel_1(); 376} // namespace native 377} // namespace at 378 """ 379 self.assertEqual("\n".join(declaration), target) 380 381 382# Test for native_function_generation 383class TestNativeFunctionGeneratrion(unittest.TestCase): 384 def setUp(self) -> None: 385 self.native_functions: list[NativeFunction] = [] 386 self.backend_indices: dict[ 387 DispatchKey, dict[OperatorName, BackendMetadata] 388 ] = defaultdict(dict) 389 yaml_entry = """ 390- func: op(Tensor self) -> Tensor 391 dispatch: 392 CompositeExplicitAutograd: op 393 autogen: op.out 394 """ 395 es = yaml.load(yaml_entry, Loader=LineLoader) 396 self.one_return_func, m = NativeFunction.from_yaml( 397 es[0], loc=Location(__file__, 1), valid_tags=set() 398 ) 399 400 BackendIndex.grow_index(self.backend_indices, m) 401 402 self.two_returns_func, two_returns_backend_index = NativeFunction.from_yaml( 403 { 404 "func": "op_2() -> (Tensor, Tensor)", 405 "dispatch": {"CPU": "kernel_1"}, 406 "autogen": "op_2.out", 407 }, 408 loc=Location(__file__, 1), 409 valid_tags=set(), 410 ) 411 BackendIndex.grow_index(self.backend_indices, two_returns_backend_index) 412 413 def test_functional_variant_autogen_out_variant(self) -> None: 414 native_functions = [self.one_return_func] 415 add_generated_native_functions(native_functions, self.backend_indices) 416 self.assertEqual(len(native_functions), 2) 417 self.assertEqual( 418 str(native_functions[1].func), 419 "op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", 420 ) 421 op_name = native_functions[1].func.name 422 backend_metadata = self.backend_indices[DispatchKey.CompositeExplicitAutograd][ 423 op_name 424 ] 425 self.assertEqual(backend_metadata.kernel, "op_out") 426 427 def test_functional_variant_autogen_out_variant_two_returns(self) -> None: 428 native_functions = [self.two_returns_func] 429 add_generated_native_functions(native_functions, self.backend_indices) 430 self.assertEqual(len(native_functions), 2) 431 self.assertEqual( 432 str(native_functions[1].func), 433 "op_2.out(*, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", 434 ) 435 op_name = native_functions[1].func.name 436 backend_metadata = self.backend_indices[DispatchKey.CompositeExplicitAutograd][ 437 op_name 438 ] 439 self.assertEqual(backend_metadata.kernel, "op_2_out") 440 441 442# Test for static_dispatch 443class TestStaticDispatchGeneratrion(unittest.TestCase): 444 def setUp(self) -> None: 445 self.backend_indices: dict[ 446 DispatchKey, dict[OperatorName, BackendMetadata] 447 ] = defaultdict(dict) 448 yaml_entry = """ 449- func: op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 450 dispatch: 451 CompositeExplicitAutograd: op 452 """ 453 es = yaml.load(yaml_entry, Loader=LineLoader) 454 self.one_return_func, m = NativeFunction.from_yaml( 455 es[0], loc=Location(__file__, 1), valid_tags=set() 456 ) 457 458 BackendIndex.grow_index(self.backend_indices, m) 459 dispatch_key = DispatchKey.CompositeExplicitAutograd 460 self.assertTrue(dispatch_key in self.backend_indices) 461 self.indices = [ 462 BackendIndex( 463 dispatch_key=dispatch_key, 464 use_out_as_primary=True, 465 external=False, 466 device_guard=False, 467 index=self.backend_indices[dispatch_key], 468 ) 469 ] 470 471 def test_op_with_1_backend_generates_static_dispatch(self) -> None: 472 disp_sig = DispatcherSignature.from_schema(self.one_return_func.func) 473 with native_function_manager(self.one_return_func): 474 out = static_dispatch( 475 sig=disp_sig, 476 f=self.one_return_func, 477 backend_indices=self.indices, 478 ) 479 self.assertEqual( 480 out, "return at::compositeexplicitautograd::op_out(out, self);" 481 ) 482 483 def test_op_with_cpp_sig_generates_static_dispatch(self) -> None: 484 sig_group = CppSignatureGroup.from_native_function( 485 self.one_return_func, 486 method=False, 487 fallback_binding=self.one_return_func.manual_cpp_binding, 488 ) 489 # cpp signature puts out at the front 490 with native_function_manager(self.one_return_func): 491 out = static_dispatch( 492 sig=sig_group.signature, 493 f=self.one_return_func, 494 backend_indices=self.indices, 495 ) 496 self.assertEqual( 497 out, "return at::compositeexplicitautograd::op_out(out, self);" 498 ) 499 500 501# Represents the most basic NativeFunction. Use dataclasses.replace() 502# to edit for use. 503DEFAULT_NATIVE_FUNCTION, _ = NativeFunction.from_yaml( 504 {"func": "func() -> bool"}, 505 loc=Location(__file__, 1), 506 valid_tags=set(), 507) 508 509 510if __name__ == "__main__": 511 unittest.main() 512