1# Parses derivatives.yaml into autograd functions 2# 3# Each autograd function is represented by `DifferentiabilityInfo` containing 4# a list of `Derivative`. See `torchgen.api.autograd` for the data models. 5 6from __future__ import annotations 7 8import re 9from collections import defaultdict 10from typing import Any, Counter, Dict, Sequence, Set, Tuple 11 12import yaml 13 14from torchgen.api import cpp 15from torchgen.api.autograd import ( 16 Derivative, 17 DifferentiabilityInfo, 18 ForwardDerivative, 19 SavedAttribute, 20) 21from torchgen.api.types import ( 22 BaseCType, 23 Binding, 24 boolT, 25 CppSignatureGroup, 26 layoutT, 27 longT, 28 NamedCType, 29 OptionalCType, 30 scalarTypeT, 31 SpecialArgName, 32 stringT, 33 symIntArrayRefT, 34 SymIntT, 35 tensorGeometryT, 36 tensorOptionsT, 37 typeAndSizeT, 38 VectorCType, 39) 40from torchgen.context import with_native_function 41from torchgen.gen import get_grouped_by_view_native_functions, parse_native_yaml 42from torchgen.model import ( 43 AUTOGRAD_KEYS, 44 FunctionSchema, 45 NativeFunction, 46 NativeFunctionsViewGroup, 47 OperatorName, 48 SchemaKind, 49 Type, 50 Variant, 51) 52from torchgen.utils import concatMap, IDENT_REGEX, split_name_params 53from torchgen.yaml_utils import YamlLoader 54 55 56DerivativeRet = Tuple[Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]], Set[str]] 57 58_GLOBAL_LOAD_DERIVATIVE_CACHE: dict[tuple[str, str], DerivativeRet] = {} 59 60_VALID_AUTOGRAD_KEYS = set(AUTOGRAD_KEYS) 61 62 63# This function directly adds per-dispatchkey derivative entries for {view}_copy variants of each view op. 64# Since every {view} and {view}_copy op shares the same derivative formula, 65# we generate them here instead of duplicating them in the yaml. 66# See Note [Codegen'd {view}_copy Operators] 67def add_view_copy_derivatives( 68 infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]], 69 view_groups: list[NativeFunctionsViewGroup], 70) -> None: 71 # Get the map from each view op's name to its corresponding view group 72 view_name_to_group: dict[OperatorName, NativeFunctionsViewGroup] = { 73 g.view.func.name: g for g in view_groups 74 } 75 76 view_infos = {} 77 78 for info_dispatch_dict in infos.values(): 79 # maybe_view_group only needs to be calculated once per info_dispatch_dict 80 maybe_view_group = None 81 view_copy_differentiability_infos = {} 82 for dispatch_key, info in info_dispatch_dict.items(): 83 maybe_view_group = view_name_to_group.get(info.func.func.name, None) 84 if maybe_view_group is not None and maybe_view_group.view_copy is not None: 85 view_copy_info = info.create_view_copy_from_view_derivative( 86 maybe_view_group 87 ) 88 if view_copy_info is not None: 89 fn_schema = view_copy_info.func.func 90 view_copy_differentiability_infos[dispatch_key] = view_copy_info 91 else: 92 break 93 # prefer manually-defined derivatives if any 94 if len(view_copy_differentiability_infos) > 0 and fn_schema not in infos: 95 assert fn_schema is not None 96 view_infos[fn_schema] = view_copy_differentiability_infos 97 98 infos.update(view_infos) 99 100 101def load_derivatives( 102 derivatives_yaml_path: str, native_yaml_path: str, tags_yaml_path: str 103) -> DerivativeRet: 104 # Do some caching as this is a deterministic function 105 global _GLOBAL_LOAD_DERIVATIVE_CACHE 106 key = (derivatives_yaml_path, native_yaml_path) 107 if key not in _GLOBAL_LOAD_DERIVATIVE_CACHE: 108 with open(derivatives_yaml_path) as f: 109 definitions = yaml.load(f, Loader=YamlLoader) 110 111 funcs = parse_native_yaml(native_yaml_path, tags_yaml_path).native_functions 112 # From the parsed native functions, separate out the (generated) view_copy functions, 113 # so we can generate derivatives for them separately. 114 native_functions_with_view_groups = get_grouped_by_view_native_functions(funcs) 115 native_functions = concatMap( 116 lambda g: [g] 117 if isinstance(g, NativeFunction) 118 else list(g.functions(include_copy=True)), 119 native_functions_with_view_groups, 120 ) 121 view_groups = [ 122 g 123 for g in native_functions_with_view_groups 124 if isinstance(g, NativeFunctionsViewGroup) 125 ] 126 127 # What's the difference between function schema v.s. signature? 128 # function schema is the complete declaration including mutability annotation / default value and etc. 129 # signature is the canonical schema for a group of functions (in-place/out/functional variants) 130 # that are semantically related. 131 functions_by_signature: dict[ 132 FunctionSchema, list[NativeFunction] 133 ] = defaultdict(list) 134 functions_by_schema: dict[str, NativeFunction] = {} 135 for function in native_functions: 136 functions_by_signature[function.func.signature()].append(function) 137 assert str(function.func) not in functions_by_schema 138 functions_by_schema[str(function.func)] = function 139 140 # Keep track of how many of which ops we've seen so we can 141 # disambiguate them with a numeric suffix. 142 op_counter = Counter[str]() 143 144 # infos is a dict that maps FunctionSchema -> a dict of per dispatch key DifferentiabilityInfos 145 # this is useful because in tools/autograd/gen_autograd.py:match_differentiability_info 146 # we ultimately need to categorize the DifferentiabilityInfos by FunctionSchema 147 infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]] = {} 148 used_dispatch_keys: set[str] = set() 149 for defn_dict in definitions: 150 # Ensure that the old derivatives.yaml schema with no dispatch key can be loaded. 151 if "dispatch" not in defn_dict: 152 specification = defn_dict.pop("name") 153 output_differentiability = defn_dict.pop( 154 "output_differentiability", None 155 ) 156 defn_dict = {"name": specification, "dispatch": {"Default": defn_dict}} 157 if output_differentiability: 158 defn_dict["output_differentiability"] = output_differentiability 159 name, per_dispatch_diffinfos = create_differentiability_info( 160 defn_dict, 161 functions_by_signature, 162 functions_by_schema, 163 op_counter, 164 used_dispatch_keys, 165 ) 166 infos[name] = per_dispatch_diffinfos 167 168 add_view_copy_derivatives(infos, view_groups) 169 170 # cache both loaded infos as well a a set of all the dispatch_keys/aliases 171 # that appear in derivatives.yaml. used_dispatch_keys is useful for generating 172 # VariableType.cpp where we need a TORCH_LIBRARY_IMPL for every autograd dispatch key used 173 _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = infos, used_dispatch_keys 174 175 return _GLOBAL_LOAD_DERIVATIVE_CACHE[key] 176 177 178# TODO: Why is this going through CppSignatureGroup, that doesn't make sense... 179@with_native_function 180def cpp_arguments(f: NativeFunction) -> Sequence[Binding]: 181 sigs = CppSignatureGroup.from_native_function(f, method=False) 182 if sigs.symint_signature is not None: 183 return sigs.symint_signature.arguments() 184 else: 185 return sigs.signature.arguments() 186 187 188def create_derivative( 189 f: NativeFunction, 190 formula: str, 191 var_names: tuple[str, ...], 192 available_named_gradients: Sequence[str], 193) -> Derivative: 194 original_formula = formula 195 arguments: list[NamedCType] = [ 196 a.nctype.remove_const_ref() for a in cpp_arguments(f) 197 ] 198 199 return_names = tuple(n if n != "self" else "result" for n in cpp.return_names(f)) 200 return_types = tuple( 201 cpp.return_type(r, symint=True).remove_const_ref() for r in f.func.returns 202 ) 203 204 named_returns = [ 205 NamedCType(name, type) for name, type in zip(return_names, return_types) 206 ] 207 208 formula, saved_inputs = saved_variables(formula, arguments, var_names) 209 formula, saved_outputs = saved_variables(formula, named_returns, var_names) 210 211 used_named_gradients = { 212 name 213 for name in available_named_gradients 214 if re.search(IDENT_REGEX.format(name), formula) 215 } 216 217 # Check that the referenced derivatives in the formula are in bounds 218 for i in used_gradient_indices(formula): 219 if i >= len(f.func.returns): 220 raise RuntimeError( 221 f"Out of bounds grads access: derivative formula for {cpp.name(f.func)} " 222 f"used grads[{i}], but the forward only returns {len(f.func.returns)} outputs." 223 ) 224 225 return Derivative( 226 formula=formula, 227 original_formula=original_formula, 228 var_names=var_names, 229 saved_inputs=saved_inputs, 230 saved_outputs=saved_outputs, 231 named_gradients=used_named_gradients, 232 ) 233 234 235def create_forward_derivative( 236 f: NativeFunction, formula: str, names: tuple[str, ...] 237) -> ForwardDerivative: 238 var_names = names 239 var_types: tuple[Type, ...] | None = None 240 for r in f.func.returns: 241 if r.name in var_names: 242 if var_types is None: 243 var_types = () 244 var_types = var_types + (r.type,) 245 246 # Handle default return names 247 if var_types is None: 248 if var_names == ("result",): 249 assert len(f.func.returns) == 1 250 var_types = (f.func.returns[0].type,) 251 else: 252 for var_name in var_names: 253 res = re.findall(r"^result(\d+)$", var_name) 254 if len(res) == 1: 255 if var_types is None: 256 var_types = () 257 arg_idx = int(res[0]) 258 var_types = var_types + (f.func.returns[arg_idx].type,) 259 260 assert var_types is not None, "No matching output for forward derivative definition" 261 return ForwardDerivative( 262 formula=formula, 263 var_names=var_names, 264 var_types=var_types, 265 required_inputs_fw_grad=None, 266 required_inputs_primal=None, 267 required_original_self_value=False, 268 is_reusing_outplace_formula=False, 269 ) 270 271 272def postprocess_forward_derivatives( 273 f: NativeFunction, 274 defn_name: str, 275 all_arg_names: list[str], 276 derivatives: list[Derivative], 277 forward_derivatives: list[ForwardDerivative], 278 args_with_derivatives: Sequence[Binding], 279) -> list[ForwardDerivative]: 280 def find_required_inputs(formula: str, postfix: str) -> tuple[str, ...]: 281 is_foreach = f.func.name.name.base.startswith("_foreach_") 282 required_inputs = set() 283 for arg in args_with_derivatives: 284 if ( 285 arg.type in ("at::TensorList", "const at::ITensorListRef &") 286 and not is_foreach 287 ): 288 # The functions taking TensorList handle everything internally 289 continue 290 arg_name = arg.name 291 292 found = re.search(IDENT_REGEX.format(arg_name), formula) 293 if found: 294 raise RuntimeError( 295 f"The forward formula for {defn_name} is using the base name of the {arg_name} " 296 f"argument which is ambiguous. You should use {arg_name}_p to access the primal " 297 f"value and {arg_name}_t to access the tangent." 298 ) 299 300 found = re.search(IDENT_REGEX.format(arg_name + postfix), formula) 301 if found: 302 required_inputs.add(arg_name) 303 304 return tuple(required_inputs) 305 306 updated_derivatives: list[ForwardDerivative] = [] 307 308 for defn in forward_derivatives: 309 formula = defn.formula 310 required_inputs_tangent = find_required_inputs(formula, "_t") 311 if formula == "auto_element_wise": 312 assert ( 313 f.func.kind() != SchemaKind.inplace 314 ), f"Cannot use auto_element_wise with {f.func.name} because it is an in-place variant" 315 if ( 316 (not len(args_with_derivatives) == 1) 317 or len(forward_derivatives) > 1 318 or len(forward_derivatives[0].var_names) > 1 319 ): 320 raise RuntimeError( 321 f"Derivative definition of {defn_name} in derivatives.yaml defines the " 322 "forward definition of gradient as element_wise but this only " 323 "works for functions with a single differentiable input and a " 324 "single differentiable output." 325 ) 326 if not len(derivatives) == 1: 327 raise RuntimeError( 328 f"Derivative definition of {defn_name} in derivatives.yaml defines the " 329 "forward definition of gradient as element_wise but it does not " 330 "defines the gradient formula for its argument which is required." 331 ) 332 # This transformation is based on the observation that for element-wise functions, the Jacobian 333 # matrix is diagonal and thus doing J * v is the same as (v^T J)^T (in practice, we ignore the transpositions) 334 # For the complex case, we use hermitian transpose and get (v.conj() J).conj() 335 # So here we are going to re-use the backward formula and replace two things: 336 # 1) all occurrences of "grad" with "foo_t.conj()", where foo is the name of the unique differentiable input. 337 # 2) all usage of an original input "foo" with its primal value "foo_p". 338 # 3) conjugate the final result 339 # For example, for abs, the backward formula is: 340 # grad * self.sgn() 341 # And this function generates a forward formula that is: 342 # (self_t.conj() * self_p.sgn()).conj() 343 344 backward_formula = derivatives[0].original_formula 345 input_name = args_with_derivatives[0].name 346 347 # Do replacement 1) of the grad 348 def repl(m: Any) -> str: 349 return f"{m.group(1)}{input_name}_t.conj(){m.group(2)}" 350 351 fw_formula = re.sub(IDENT_REGEX.format("grad"), repl, backward_formula) 352 353 # Do replacement 2) of the input variables 354 for arg in args_with_derivatives: 355 arg_name = arg.name 356 357 def repl(m: Any) -> str: 358 return f"{m.group(1)}{arg_name}_p{m.group(2)}" 359 360 fw_formula = re.sub(IDENT_REGEX.format(arg_name), repl, fw_formula) 361 362 # Do the final conjugate 3) 363 fw_formula = f"({fw_formula}).conj()" 364 365 # Since there is a single differentiable inputs and we necessarily need its tangent we can 366 # simply require all differentiable input's tangent. 367 required_inputs_tangent = tuple(all_arg_names) 368 formula = fw_formula 369 elif formula == "auto_linear": 370 if ( 371 len(forward_derivatives) > 1 372 or len(forward_derivatives[0].var_names) > 1 373 ): 374 raise RuntimeError( 375 f"Derivative definition of {defn_name} in derivatives.yaml defines the " 376 "forward definition of gradient as linear but this only works " 377 "for functions with a single differentiable output." 378 ) 379 # This transformation is based on the observation that linear functions can be written as: 380 # y = f(x) = A * x 381 # For some matrix A and the Jacobian of the function f is also A. 382 # So doing J * v = A * v = f(v). 383 # Hence to do the jvp, we simply need to evaluate the function at the point v instead of x. 384 # We do this by calling the forward again by replacing any occurrence of the differentiable 385 # input "foo" by it's tangent "foo_t". 386 # Note that multiple inputs are not a problem as long as the function is truly linear wrt to 387 # the vector where all the differentiable inputs are stacked. 388 389 diff_arg_names = [arg.name for arg in args_with_derivatives] 390 assert len(diff_arg_names) > 0 391 392 # Do replacement of input variables 393 new_args = [] 394 for arg_name in all_arg_names: 395 if arg_name in diff_arg_names: 396 arg_name = arg_name + "_t" 397 new_args.append(arg_name) 398 399 # TODO we are trolling 400 if f.func.has_symint(): 401 defn_name += "_symint" 402 403 # Call into the forward again. We need two cases here to handle both Tensor methods and at:: functions. 404 if Variant.function in f.variants: 405 fw_formula = f"at::{defn_name}({', '.join(new_args)})" 406 else: 407 assert Variant.method in f.variants 408 fw_formula = f"{new_args[0]}.{defn_name}({', '.join(new_args[1:])})" 409 410 # All of the input tangents are always used so all of them are required here. 411 required_inputs_tangent = tuple(diff_arg_names) 412 formula = fw_formula 413 414 # At this point, the formula is final and is not modified anymore. 415 416 # During forward formula, we use the primal instead of the input Tensors. 417 # This call inspects the formula to find for which input's primal are used. 418 required_inputs_primal = find_required_inputs(formula, "_p") 419 420 updated_derivatives.append( 421 ForwardDerivative( 422 formula=formula, 423 var_names=defn.var_names, 424 var_types=defn.var_types, 425 required_inputs_fw_grad=required_inputs_tangent, 426 required_inputs_primal=required_inputs_primal, 427 required_original_self_value=False, 428 is_reusing_outplace_formula=False, 429 ) 430 ) 431 432 return updated_derivatives 433 434 435def is_forward_derivative_definition( 436 all_arg_names: list[str], names: tuple[str, ...] 437) -> bool: 438 for name in names: 439 return name not in all_arg_names 440 raise RuntimeError("Expected `names` to be non-empty") 441 442 443def create_differentiability_info( 444 defn_dict: dict[Any, Any], 445 functions_by_signature: dict[FunctionSchema, list[NativeFunction]], 446 functions_by_schema: dict[str, NativeFunction], 447 op_counter: Counter[str], 448 used_dispatch_keys: set[str], 449) -> tuple[FunctionSchema, dict[str, DifferentiabilityInfo]]: 450 """Processes a single entry `defn` in derivatives.yaml""" 451 452 def canonical_function( 453 functions: Sequence[NativeFunction], name: str 454 ) -> NativeFunction: 455 for f in functions: 456 if ( 457 not f.func.is_functional_fn() 458 and not f.func.is_out_fn() 459 and name == str(f.func.name.name) 460 ): 461 return f 462 # some functions only have in-place variants 463 assert name + "_" == cpp.name(functions[0].func) 464 return functions[0] 465 466 def split_names(raw_names: str) -> tuple[str, ...]: 467 """Given "foo, bar", return ["foo", "bar"].""" 468 return tuple(x.strip() for x in raw_names.split(",")) 469 470 def check_grad_usage(defn_name: str, derivatives: Sequence[Derivative]) -> None: 471 """ 472 Check for some subtle mistakes one might make when writing derivatives. 473 These mistakes will compile, but will be latent until a function is 474 used with double backwards. 475 """ 476 477 uses_grad = False # true if any derivative uses "grad" 478 num_grads_uses = 0 # count of uses of "grads" or "grads[INDEX]" 479 uses_named_grads = False # true if any derivative uses "grad_{name}" 480 used_grads_indices: list[int] = [] # which indices of grads are used 481 for d in derivatives: 482 formula = d.formula 483 uses_grad = uses_grad or bool( 484 re.findall(IDENT_REGEX.format("grad"), formula) 485 ) 486 num_grads_uses += len(re.findall(IDENT_REGEX.format("grads"), formula)) 487 uses_named_grads = uses_named_grads or bool(d.named_gradients) 488 used_grads_indices.extend(used_gradient_indices(formula)) 489 # This is a basic sanity check: the number of places we see 490 # "grads" should be no fewer than the number of indices we see 491 # inside "grads". They may not be equal because we may use 492 # "grads" without an index. 493 assert num_grads_uses >= len(used_grads_indices) 494 # Thus if the number is equal, every use of grads is also 495 # indexed. 496 only_used_grads_indices = num_grads_uses == len(used_grads_indices) 497 498 if uses_grad and num_grads_uses > 0: 499 raise RuntimeError( 500 f"Derivative definition of {defn_name} in derivatives.yaml illegally " 501 "mixes use of 'grad' and 'grads'. Consider replacing " 502 "occurrences of 'grad' with 'grads[0]'" 503 ) 504 505 if only_used_grads_indices and set(used_grads_indices) == {0}: 506 raise RuntimeError( 507 f"Derivative definition of {defn_name} in derivatives.yaml solely " 508 "refers to 'grads[0]'. If the first output is indeed the " 509 "only differentiable output, replace 'grads[0]' with 'grad'; " 510 "otherwise, there is a likely error in your derivatives " 511 "declaration." 512 ) 513 514 if uses_named_grads and (uses_grad or num_grads_uses > 0): 515 raise RuntimeError( 516 f"Derivative definition of {defn_name} in derivatives.yaml illegally " 517 'mixes use of "grad_RETURN_NAME" and "grad" or "grads[x]". Use ' 518 "only one method for identifying gradients." 519 ) 520 521 @with_native_function 522 def set_up_derivatives( 523 f: NativeFunction, 524 ) -> tuple[ 525 Sequence[Derivative], 526 Sequence[ForwardDerivative], 527 Sequence[Binding], 528 Sequence[str], 529 Sequence[str], 530 ]: 531 # Set up the derivative information 532 derivatives: list[Derivative] = [] 533 forward_derivatives: list[ForwardDerivative] = [] 534 non_differentiable_arg_names: list[str] = [] 535 args_with_derivatives_set: set[str] = set() 536 537 all_arg_names = [a.name for a in cpp_arguments(f)] 538 all_ret_names = [ 539 r.name for r in f.func.returns 540 ] # only used for the assert below 541 # output_differentiability is captured from the enclosed 542 # scope. Don't modify it. 543 # 544 # If it is not present, then no output is explicitly 545 # undifferentiable. 546 # 547 # It may be present and shorter than the length of return 548 # values. If that's the case, any return value that does not 549 # have a corresponding entry is considered not differentiable. 550 differentiability = output_differentiability or [True] * len(f.func.returns) 551 # A return is available as a named gradient ... 552 available_named_gradients = [ 553 f"grad_{ret.name}" 554 for ret, differentiable in zip(f.func.returns, differentiability) 555 # if it has not been explicitly made undifferentiable 556 if differentiable 557 # and if it has a name 558 and ret.name is not None 559 # and if its type is differentiable 560 and ret.type.is_tensor_like() 561 ] 562 563 for raw_names in sorted(defn.keys()): 564 formula = defn[raw_names] 565 names = split_names(raw_names) 566 567 for name in names: 568 assert not (name in all_arg_names and name in all_ret_names), ( 569 f"While processing the derivative formula for '{f.func.name}' wrt '{name}', " 570 f"expected '{name}' to not be both an input arg and named return. " 571 ) 572 573 if is_forward_derivative_definition(all_arg_names, names): 574 forward_derivatives.append(create_forward_derivative(f, formula, names)) 575 else: 576 if formula.lower().strip() == "non_differentiable": 577 non_differentiable_arg_names += names 578 else: 579 derivative = create_derivative( 580 f, formula, names, available_named_gradients 581 ) 582 derivatives.append(derivative) 583 args_with_derivatives_set |= set(names) 584 585 overlap = args_with_derivatives_set.intersection(non_differentiable_arg_names) 586 if overlap: 587 raise RuntimeError( 588 f"derivatives definition for {defn} have overlapped non_differentiable " 589 f"and differentiable variables: {overlap}" 590 ) 591 592 # Next, let us determine the list of inputs in order. 593 # TODO: do we need eagerly calculate and save it here? Can it be derived 594 # from NativeFunction and `derivatives` on callsites instead? 595 args_with_derivatives = [ 596 a for a in cpp_arguments(f) if a.name in args_with_derivatives_set 597 ] 598 599 # Postprocess forward derivatives definitions now that we know the differentiable arguments 600 forward_derivatives = postprocess_forward_derivatives( 601 f, 602 defn_name, 603 all_arg_names, 604 derivatives, 605 forward_derivatives, 606 args_with_derivatives, 607 ) 608 609 # Test to see if the use of 'grads' makes sense. 610 check_grad_usage(defn_name, derivatives) 611 612 return ( 613 derivatives, 614 forward_derivatives, 615 args_with_derivatives, 616 non_differentiable_arg_names, 617 available_named_gradients, 618 ) 619 620 # NB: Removes 'name' from defn dictionary 621 specification = defn_dict.pop("name") 622 defn_name, _ = split_name_params(specification) 623 # NB: Removes 'output_differentiability' from defn dictionary 624 # `None` means all differentiable. 625 output_differentiability = defn_dict.pop("output_differentiability", None) 626 output_differentiability_conditions = None 627 if output_differentiability and any( 628 isinstance(diff, str) for diff in output_differentiability 629 ): 630 if len(output_differentiability) != 1: 631 raise RuntimeError( 632 f"Not supported: for {specification}," 633 f"output_differentiability must either be " 634 f"List[bool] or a List[str] where each str is a " 635 f"condition. In the case where it is a condition, " 636 f"we only support single-output functions. " 637 f"Please file us an issue. " 638 ) 639 output_differentiability_conditions = output_differentiability 640 output_differentiability = [True] 641 642 schema_function = functions_by_schema.get(specification) 643 if not schema_function: 644 avail = "\n".join( 645 k for k, v in functions_by_schema.items() if cpp.name(v.func) == defn_name 646 ) 647 raise RuntimeError( 648 f"could not find ATen function for schema: {specification} " 649 f". Available signatures:\n{avail}" 650 ) 651 652 # now map this to the legacy schema; this isn't technically necessary, but we'd need some logic here 653 # to map in-place schemas to the out-of-place variants. 654 # TODO: maybe the logic to handle the legacy schema is no longer necessary? 655 signature = schema_function.func.signature() 656 functions = functions_by_signature[signature] 657 if len(functions) == 0: 658 avail = "\n".join( 659 str(k) 660 for k, v in functions_by_signature.items() 661 if cpp.name(k) == defn_name 662 ) 663 raise RuntimeError( 664 f"could not find ATen function for legacy signature: {signature} " 665 f"corresponding to schema {specification}. Please report a bug to PyTorch. " 666 f"Available signatures:\n{avail}" 667 ) 668 669 canonical = canonical_function(functions, defn_name) 670 if "grad_input_mask" in (a.name for a in cpp_arguments(canonical)): 671 raise RuntimeError( 672 f"Schema for {defn_name} has an argument named grad_input_mask, " 673 "but this name would be shadowed by our codegen. " 674 "Please use a different name in native_functions.yaml." 675 ) 676 677 if "result" in (a.name for a in cpp_arguments(canonical)): 678 raise RuntimeError( 679 f"Schema for {defn_name} has an argument named result, " 680 "but this is only allowed for outputs." 681 "Please use a different name in native_functions.yaml." 682 ) 683 684 diffinfo_dict = {} 685 for key, defn in defn_dict["dispatch"].items(): 686 if key != "Default" and key not in _VALID_AUTOGRAD_KEYS: 687 raise RuntimeError( 688 f"Invalid dispatch key {key} in derivatives.yaml for {specification}," 689 f" expected key to be one of {_VALID_AUTOGRAD_KEYS}" 690 ) 691 if key not in used_dispatch_keys: 692 used_dispatch_keys.add(key) 693 694 ( 695 derivatives, 696 forward_derivatives, 697 args_with_derivatives, 698 non_differentiable_arg_names, 699 available_named_gradients, 700 ) = set_up_derivatives(canonical) 701 702 used_named_gradients: set[str] = set() 703 for d in derivatives: 704 used_named_gradients |= d.named_gradients 705 706 # only assign an op name if we are actually going to calculate a derivative 707 op = None 708 if args_with_derivatives: 709 op_prefix = _create_op_prefix(defn_name) 710 if key != "Default": 711 op_prefix = op_prefix + key 712 op = f"{op_prefix}{op_counter[op_prefix]}" 713 op_counter[op_prefix] += 1 714 715 diffinfo_dict[key] = DifferentiabilityInfo( 716 name=defn_name, 717 func=canonical, 718 op=op, 719 derivatives=derivatives, 720 forward_derivatives=forward_derivatives, 721 all_saved_inputs=dedup_vars( 722 [v for d in derivatives for v in d.saved_inputs] 723 ), 724 all_saved_outputs=dedup_vars( 725 [v for d in derivatives for v in d.saved_outputs] 726 ), 727 available_named_gradients=available_named_gradients, 728 used_named_gradients=used_named_gradients, 729 args_with_derivatives=args_with_derivatives, 730 non_differentiable_arg_names=non_differentiable_arg_names, 731 output_differentiability=output_differentiability, 732 output_differentiability_conditions=output_differentiability_conditions, 733 ) 734 735 return canonical.func, diffinfo_dict 736 737 738GRAD_INDEX_REGEX = r"(?:^|\W)grads\[(\d+)\]" 739 740 741def used_gradient_indices(formula: str) -> list[int]: 742 """Determine a list of gradient indices (the i in grads[i]) that 743 are used by the formula. 744 745 >>> used_gradient_indices("foo(grads[0], grads[1])") 746 [0, 1] 747 """ 748 return [int(i) for i in re.findall(GRAD_INDEX_REGEX, formula)] 749 750 751def saved_variables( 752 formula: str, 753 nctypes: list[NamedCType], 754 var_names: tuple[str, ...], 755) -> tuple[str, tuple[SavedAttribute, ...]]: 756 def stride_expr(name: str) -> str: 757 assert var_names == (name,), ( 758 'Replacement for ".strides()" is currently only supported for single derivatives of the same tensor ' 759 'that ".strides()" is being called on.' 760 ) 761 return f'strides_or_error({name}, "{name}")' 762 763 REPLACEMENTS: list[tuple[str, dict[str, Any]]] = [ 764 # replace self.sym_sizes() with self_sym_sizes 765 ( 766 r"{}.sym_sizes\(\)", 767 { 768 "suffix": "_sym_sizes", 769 "nctype": lambda name: NamedCType(name, BaseCType(symIntArrayRefT)), 770 }, 771 ), 772 # replace self->sym_sizes() with self_sym_sizes_opt 773 ( 774 r"{}->sym_sizes\(\)", 775 { 776 "suffix": "_sym_sizes_opt", 777 "nctype": lambda name: NamedCType( 778 name, OptionalCType(BaseCType(symIntArrayRefT)) 779 ), 780 "expr": lambda name: f"{name}.has_value() ? std::optional<c10::SymIntArrayRef>({name}->sym_sizes()) : std::nullopt", 781 }, 782 ), 783 # replace self.sym_blocksize() with self_sym_blocksize_opt 784 ( 785 r"{}.sym_blocksize\(\)", 786 { 787 "suffix": "_self_sym_blocksize_opt", 788 "nctype": lambda name: NamedCType( 789 name, OptionalCType(BaseCType(symIntArrayRefT)) 790 ), 791 "expr": lambda name: f"at::sparse_csr::getSymIntBlockSize({name})", 792 }, 793 ), 794 # replace self.options() with self_options 795 ( 796 r"{}.options\(\)", 797 { 798 "suffix": "_options", 799 "nctype": lambda name: NamedCType(name, BaseCType(tensorOptionsT)), 800 }, 801 ), 802 # replace zeros_like(self) with self_info 803 ( 804 r"zeros_like\({}\)", 805 { 806 "suffix": "_info", 807 "nctype": lambda name: NamedCType(name, BaseCType(typeAndSizeT)), 808 "expr": lambda name: name, # at save-time 809 "res": lambda name: name + "_info.zeros()", # at eval-time 810 }, 811 ), 812 # replace self.sym_size(2) with self_sym_size_2 813 ( 814 r"{}.sym_size\((-?\w+)\)", 815 { 816 "suffix": lambda m: f"_sym_argsize_{m.groups()[0].replace('-', 'minus_')}", 817 "nctype": lambda name: NamedCType(name, BaseCType(SymIntT)), 818 }, 819 ), 820 # replace self.numel() with self_numel 821 ( 822 r"{}.numel\(\)", 823 { 824 "suffix": "_numel", 825 "nctype": lambda name: NamedCType(name, BaseCType(longT)), 826 }, 827 ), 828 # replace self.sym_numel() with self_sym_numel 829 ( 830 r"{}.sym_numel\(\)", 831 { 832 "suffix": "_sym_numel", 833 "nctype": lambda name: NamedCType(name, BaseCType(SymIntT)), 834 }, 835 ), 836 # replace to_args_sizes(self) with self_args_sizes 837 ( 838 r"to_args_sizes\({}\)", 839 { 840 "suffix": "_args_sizes", 841 "nctype": lambda name: NamedCType( 842 name, VectorCType(VectorCType(BaseCType(longT))) 843 ), 844 }, 845 ), 846 # replace to_args_sizes_symint(self) with self_args_sizes 847 ( 848 r"to_args_sizes_symint\({}\)", 849 { 850 "suffix": "_args_sizes_symint", 851 "nctype": lambda name: NamedCType( 852 name, VectorCType(VectorCType(BaseCType(SymIntT))) 853 ), 854 }, 855 ), 856 # replace to_args_scalartypes(self) with self_args_scalartypes 857 ( 858 r"to_args_scalartypes\({}\)", 859 { 860 "suffix": "_args_scalartypes", 861 "nctype": lambda name: NamedCType( 862 name, VectorCType(BaseCType(scalarTypeT)) 863 ), 864 }, 865 ), 866 # replace TensorGeometry(self) with self_geometry 867 ( 868 r"TensorGeometry\({}\)", 869 { 870 "suffix": "_geometry", 871 "nctype": lambda name: NamedCType(name, BaseCType(tensorGeometryT)), 872 }, 873 ), 874 ( 875 r"{}.scalar_type\(\)", 876 { 877 "suffix": "_scalar_type", 878 "nctype": lambda name: NamedCType(name, BaseCType(scalarTypeT)), 879 }, 880 ), 881 # replace self.dim() with self_dim 882 ( 883 r"{}.dim\(\)", 884 { 885 "suffix": "_dim", 886 "nctype": lambda name: NamedCType(name, BaseCType(longT)), 887 }, 888 ), 889 # replace self.sym_strides() with self_sym_strides 890 ( 891 r"{}.sym_strides\(\)", 892 { 893 "suffix": "_sym_strides", 894 "nctype": lambda name: NamedCType(name, BaseCType(symIntArrayRefT)), 895 "expr": stride_expr, 896 }, 897 ), 898 # replace self.layout() with self_layout 899 ( 900 r"{}.layout\(\)", 901 { 902 "suffix": "_layout", 903 "nctype": lambda name: NamedCType(name, BaseCType(layoutT)), 904 }, 905 ), 906 # replace self.is_conj() with self_conjugate 907 ( 908 r"{}.is_conj\(\)", 909 { 910 "suffix": "_conjugate", 911 "nctype": lambda name: NamedCType(name, BaseCType(boolT)), 912 }, 913 ), 914 ] 915 916 # find which arguments need to be saved 917 saved: list[SavedAttribute] = [] 918 919 if ".sizes()" in formula or "->sizes()" in formula: 920 raise RuntimeError( 921 ".sizes() is not supported in derivative formulas. Instead, please use the SymInt version," 922 + f".sym_sizes(), which returned a c10::SymIntArrayRef. formula={formula}" 923 ) 924 if re.search(r"\.size\([-]?\d+\)", formula) or re.search( 925 r"->size\([-]?\d+\)", formula 926 ): 927 raise RuntimeError( 928 ".size(int) is not supported in derivative formulas. Instead, please use the SymInt version," 929 + f".sym_size(int), which returned a c10::SymIntArrayRef. formula={formula}" 930 ) 931 if ".strides()" in formula or "->strides()" in formula: 932 raise RuntimeError( 933 ".strides() is not supported in derivative formulas. Instead, please use the SymInt version," 934 + f".sym_strides(), which returned a c10::SymIntArrayRef. formula={formula}" 935 ) 936 for nctype in nctypes: 937 name = ( 938 nctype.name.name if isinstance(nctype.name, SpecialArgName) else nctype.name 939 ) 940 # First search the formula for expressions which can be evaluated 941 # when the autograd Function is created to avoid saving variables 942 for regex, info in REPLACEMENTS: 943 944 def repl(m: re.Match[str]) -> str: 945 suffix: str = ( 946 info["suffix"](m) if callable(info["suffix"]) else info["suffix"] 947 ) 948 expr: str = info["expr"](name) if "expr" in info else m.group(0) 949 saved.append( 950 SavedAttribute( 951 nctype=info["nctype"](name + suffix), 952 expr=expr, 953 ) 954 ) 955 if "res" in info: 956 replacement: str = info["res"](name) 957 return replacement 958 return name + suffix 959 960 formula = re.sub(regex.format(name), repl, formula) 961 962 # std::optional<std::string> types stored in Backward nodes must be 963 # converted to std::optional<std::string_view> before being passed into 964 # the backward function 965 if nctype.type == OptionalCType(BaseCType(stringT)): 966 formula = re.sub( 967 rf"\b{name}\b", 968 f"{name}.has_value() ? std::optional<c10::string_view>({name}.value()) : std::nullopt", 969 formula, 970 ) 971 972 # Find any variables which remain in the formula and save them 973 if re.search(IDENT_REGEX.format(name), formula): 974 saved.append( 975 SavedAttribute( 976 nctype=nctype, 977 expr=name, 978 ) 979 ) 980 981 return formula, tuple(saved) 982 983 984def _create_op_prefix(name: str) -> str: 985 """Takes a native function name converts to a op prefix name. 986 987 Note that the "name" parameter must be the native function name 988 without the optional variant suffix, so "add" instead of 989 "add.out". 990 991 OP names correspond to classes, hence the change to title case. 992 993 Example:: 994 >>> _create_op_prefix('add') 995 'AddBackward' 996 """ 997 camel_case = "".join([p.title() for p in name.split("_")]) 998 return (camel_case + "Backward").replace("ForwardBackward", "Backward") 999 1000 1001def dedup_vars(vars: Sequence[SavedAttribute]) -> Sequence[SavedAttribute]: 1002 seen: set[str] = set() 1003 saved: list[SavedAttribute] = [] 1004 for var in vars: 1005 name = ( 1006 var.nctype.name.name 1007 if isinstance(var.nctype.name, SpecialArgName) 1008 else var.nctype.name 1009 ) 1010 if name in seen: 1011 continue 1012 seen.add(name) 1013 saved.append(var) 1014 return saved 1015