# Parses derivatives.yaml into autograd functions
#
# Each autograd function is represented by `DifferentiabilityInfo` containing
# a list of `Derivative`. See `torchgen.api.autograd` for the data models.
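#
# A typical derivatives.yaml entry looks roughly like this (illustrative sketch,
# not copied verbatim from the yaml):
#
#   - name: mul.Tensor(Tensor self, Tensor other) -> Tensor
#     self: grad * other.conj()
#     other: grad * self.conj()
#     result: other_t * self_p + self_t * other_p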

from __future__ import annotations

import re
from collections import defaultdict
from typing import Any, Counter, Dict, Sequence, Set, Tuple

import yaml

from torchgen.api import cpp
from torchgen.api.autograd import (
    Derivative,
    DifferentiabilityInfo,
    ForwardDerivative,
    SavedAttribute,
)
from torchgen.api.types import (
    BaseCType,
    Binding,
    boolT,
    CppSignatureGroup,
    layoutT,
    longT,
    NamedCType,
    OptionalCType,
    scalarTypeT,
    SpecialArgName,
    stringT,
    symIntArrayRefT,
    SymIntT,
    tensorGeometryT,
    tensorOptionsT,
    typeAndSizeT,
    VectorCType,
)
from torchgen.context import with_native_function
from torchgen.gen import get_grouped_by_view_native_functions, parse_native_yaml
from torchgen.model import (
    AUTOGRAD_KEYS,
    FunctionSchema,
    NativeFunction,
    NativeFunctionsViewGroup,
    OperatorName,
    SchemaKind,
    Type,
    Variant,
)
from torchgen.utils import concatMap, IDENT_REGEX, split_name_params
from torchgen.yaml_utils import YamlLoader


DerivativeRet = Tuple[Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]], Set[str]]

_GLOBAL_LOAD_DERIVATIVE_CACHE: dict[tuple[str, str], DerivativeRet] = {}

_VALID_AUTOGRAD_KEYS = set(AUTOGRAD_KEYS)


# This function directly adds per-dispatchkey derivative entries for {view}_copy variants of each view op.
# Since every {view} and {view}_copy op shares the same derivative formula,
# we generate them here instead of duplicating them in the yaml.
# See Note [Codegen'd {view}_copy Operators]
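# For example, a formula written for a view op `foo` (hypothetical name) is reused,
# via create_view_copy_from_view_derivative below, to produce the entry for its
# codegen'd `foo_copy` counterpart.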
def add_view_copy_derivatives(
    infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
    view_groups: list[NativeFunctionsViewGroup],
) -> None:
    # Get the map from each view op's name to its corresponding view group
    view_name_to_group: dict[OperatorName, NativeFunctionsViewGroup] = {
        g.view.func.name: g for g in view_groups
    }

    view_infos = {}

    for info_dispatch_dict in infos.values():
        # maybe_view_group only needs to be calculated once per info_dispatch_dict
        maybe_view_group = None
        view_copy_differentiability_infos = {}
        for dispatch_key, info in info_dispatch_dict.items():
            maybe_view_group = view_name_to_group.get(info.func.func.name, None)
            if maybe_view_group is not None and maybe_view_group.view_copy is not None:
                view_copy_info = info.create_view_copy_from_view_derivative(
                    maybe_view_group
                )
                if view_copy_info is not None:
                    fn_schema = view_copy_info.func.func
                    view_copy_differentiability_infos[dispatch_key] = view_copy_info
            else:
                break
        # prefer manually-defined derivatives if any
        if len(view_copy_differentiability_infos) > 0 and fn_schema not in infos:
            assert fn_schema is not None
            view_infos[fn_schema] = view_copy_differentiability_infos

    infos.update(view_infos)


def load_derivatives(
    derivatives_yaml_path: str, native_yaml_path: str, tags_yaml_path: str
) -> DerivativeRet:
    # Do some caching as this is a deterministic function
    global _GLOBAL_LOAD_DERIVATIVE_CACHE
    key = (derivatives_yaml_path, native_yaml_path)
    if key not in _GLOBAL_LOAD_DERIVATIVE_CACHE:
        with open(derivatives_yaml_path) as f:
            definitions = yaml.load(f, Loader=YamlLoader)

        funcs = parse_native_yaml(native_yaml_path, tags_yaml_path).native_functions
        # From the parsed native functions, separate out the (generated) view_copy functions,
        # so we can generate derivatives for them separately.
        native_functions_with_view_groups = get_grouped_by_view_native_functions(funcs)
        native_functions = concatMap(
            lambda g: [g]
            if isinstance(g, NativeFunction)
            else list(g.functions(include_copy=True)),
            native_functions_with_view_groups,
        )
        view_groups = [
            g
            for g in native_functions_with_view_groups
            if isinstance(g, NativeFunctionsViewGroup)
        ]

        # What's the difference between a function schema and a signature?
        # A function schema is the complete declaration, including mutability annotations, default values, etc.
        # A signature is the canonical schema for a group of functions (in-place/out/functional variants)
        # that are semantically related.
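        # For example, the functional, in-place and out= variants of an op
        # (e.g. add.Tensor, add_.Tensor, add.out) all map to one shared signature.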
        functions_by_signature: dict[
            FunctionSchema, list[NativeFunction]
        ] = defaultdict(list)
        functions_by_schema: dict[str, NativeFunction] = {}
        for function in native_functions:
            functions_by_signature[function.func.signature()].append(function)
            assert str(function.func) not in functions_by_schema
            functions_by_schema[str(function.func)] = function

        # Keep track of how many of which ops we've seen so we can
        # disambiguate them with a numeric suffix.
        op_counter = Counter[str]()

        # infos is a dict that maps FunctionSchema -> a dict of per-dispatch-key DifferentiabilityInfos
        # this is useful because in tools/autograd/gen_autograd.py:match_differentiability_info
        # we ultimately need to categorize the DifferentiabilityInfos by FunctionSchema
        infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]] = {}
        used_dispatch_keys: set[str] = set()
        for defn_dict in definitions:
            # Ensure that the old derivatives.yaml schema with no dispatch key can be loaded.
            if "dispatch" not in defn_dict:
                specification = defn_dict.pop("name")
                output_differentiability = defn_dict.pop(
                    "output_differentiability", None
                )
                defn_dict = {"name": specification, "dispatch": {"Default": defn_dict}}
                if output_differentiability:
                    defn_dict["output_differentiability"] = output_differentiability
            name, per_dispatch_diffinfos = create_differentiability_info(
                defn_dict,
                functions_by_signature,
                functions_by_schema,
                op_counter,
                used_dispatch_keys,
            )
            infos[name] = per_dispatch_diffinfos

        add_view_copy_derivatives(infos, view_groups)

        # cache both the loaded infos as well as a set of all the dispatch_keys/aliases
        # that appear in derivatives.yaml. used_dispatch_keys is useful for generating
        # VariableType.cpp where we need a TORCH_LIBRARY_IMPL for every autograd dispatch key used
        _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = infos, used_dispatch_keys

    return _GLOBAL_LOAD_DERIVATIVE_CACHE[key]


# TODO: Why is this going through CppSignatureGroup, that doesn't make sense...
@with_native_function
def cpp_arguments(f: NativeFunction) -> Sequence[Binding]:
    sigs = CppSignatureGroup.from_native_function(f, method=False)
    if sigs.symint_signature is not None:
        return sigs.symint_signature.arguments()
    else:
        return sigs.signature.arguments()


def create_derivative(
    f: NativeFunction,
    formula: str,
    var_names: tuple[str, ...],
    available_named_gradients: Sequence[str],
) -> Derivative:
    original_formula = formula
    arguments: list[NamedCType] = [
        a.nctype.remove_const_ref() for a in cpp_arguments(f)
    ]

    return_names = tuple(n if n != "self" else "result" for n in cpp.return_names(f))
    return_types = tuple(
        cpp.return_type(r, symint=True).remove_const_ref() for r in f.func.returns
    )

    named_returns = [
        NamedCType(name, type) for name, type in zip(return_names, return_types)
    ]

    formula, saved_inputs = saved_variables(formula, arguments, var_names)
    formula, saved_outputs = saved_variables(formula, named_returns, var_names)

    used_named_gradients = {
        name
        for name in available_named_gradients
        if re.search(IDENT_REGEX.format(name), formula)
    }

    # Check that the referenced derivatives in the formula are in bounds
    for i in used_gradient_indices(formula):
        if i >= len(f.func.returns):
            raise RuntimeError(
                f"Out of bounds grads access: derivative formula for {cpp.name(f.func)} "
                f"used grads[{i}], but the forward only returns {len(f.func.returns)} outputs."
            )

    return Derivative(
        formula=formula,
        original_formula=original_formula,
        var_names=var_names,
        saved_inputs=saved_inputs,
        saved_outputs=saved_outputs,
        named_gradients=used_named_gradients,
    )


def create_forward_derivative(
    f: NativeFunction, formula: str, names: tuple[str, ...]
) -> ForwardDerivative:
    var_names = names
    var_types: tuple[Type, ...] | None = None
    for r in f.func.returns:
        if r.name in var_names:
            if var_types is None:
                var_types = ()
            var_types = var_types + (r.type,)

    # Handle default return names
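    # "result" refers to the single output of a one-output op; "result<i>"
    # (e.g. "result0") refers to the i-th output of a multi-output op.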
    if var_types is None:
        if var_names == ("result",):
            assert len(f.func.returns) == 1
            var_types = (f.func.returns[0].type,)
        else:
            for var_name in var_names:
                res = re.findall(r"^result(\d+)$", var_name)
                if len(res) == 1:
                    if var_types is None:
                        var_types = ()
                    arg_idx = int(res[0])
                    var_types = var_types + (f.func.returns[arg_idx].type,)

    assert var_types is not None, "No matching output for forward derivative definition"
    return ForwardDerivative(
        formula=formula,
        var_names=var_names,
        var_types=var_types,
        required_inputs_fw_grad=None,
        required_inputs_primal=None,
        required_original_self_value=False,
        is_reusing_outplace_formula=False,
    )


def postprocess_forward_derivatives(
    f: NativeFunction,
    defn_name: str,
    all_arg_names: list[str],
    derivatives: list[Derivative],
    forward_derivatives: list[ForwardDerivative],
    args_with_derivatives: Sequence[Binding],
) -> list[ForwardDerivative]:
    def find_required_inputs(formula: str, postfix: str) -> tuple[str, ...]:
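        # e.g. a use of "self_t" in the formula (postfix "_t") marks "self" as a
        # required tangent input; "self_p" (postfix "_p") marks a required primal.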
        is_foreach = f.func.name.name.base.startswith("_foreach_")
        required_inputs = set()
        for arg in args_with_derivatives:
            if (
                arg.type in ("at::TensorList", "const at::ITensorListRef &")
                and not is_foreach
            ):
                # The functions taking TensorList handle everything internally
                continue
            arg_name = arg.name

            found = re.search(IDENT_REGEX.format(arg_name), formula)
            if found:
                raise RuntimeError(
                    f"The forward formula for {defn_name} is using the base name of the {arg_name} "
                    f"argument which is ambiguous. You should use {arg_name}_p to access the primal "
                    f"value and {arg_name}_t to access the tangent."
                )

            found = re.search(IDENT_REGEX.format(arg_name + postfix), formula)
            if found:
                required_inputs.add(arg_name)

        return tuple(required_inputs)

    updated_derivatives: list[ForwardDerivative] = []

    for defn in forward_derivatives:
        formula = defn.formula
        required_inputs_tangent = find_required_inputs(formula, "_t")
        if formula == "auto_element_wise":
            assert (
                f.func.kind() != SchemaKind.inplace
            ), f"Cannot use auto_element_wise with {f.func.name} because it is an in-place variant"
            if (
                (not len(args_with_derivatives) == 1)
                or len(forward_derivatives) > 1
                or len(forward_derivatives[0].var_names) > 1
            ):
                raise RuntimeError(
                    f"Derivative definition of {defn_name} in derivatives.yaml defines the "
                    "forward definition of gradient as element_wise but this only "
                    "works for functions with a single differentiable input and a "
                    "single differentiable output."
                )
            if not len(derivatives) == 1:
                raise RuntimeError(
                    f"Derivative definition of {defn_name} in derivatives.yaml defines the "
                    "forward definition of gradient as element_wise but it does not "
                    "define the gradient formula for its argument, which is required."
                )
            # This transformation is based on the observation that for element-wise functions, the Jacobian
            # matrix is diagonal and thus doing J * v is the same as (v^T J)^T (in practice, we ignore the transpositions)
            # For the complex case, we use hermitian transpose and get (v.conj() J).conj()
            # So here we are going to re-use the backward formula and replace three things:
            # 1) all occurrences of "grad" with "foo_t.conj()", where foo is the name of the unique differentiable input.
            # 2) all usage of an original input "foo" with its primal value "foo_p".
            # 3) conjugate the final result
            # For example, for abs, the backward formula is:
            #   grad * self.sgn()
            # And this function generates a forward formula that is:
            #   (self_t.conj() * self_p.sgn()).conj()

            backward_formula = derivatives[0].original_formula
            input_name = args_with_derivatives[0].name

            # Do replacement 1) of the grad
            def repl(m: Any) -> str:
                return f"{m.group(1)}{input_name}_t.conj(){m.group(2)}"

            fw_formula = re.sub(IDENT_REGEX.format("grad"), repl, backward_formula)

            # Do replacement 2) of the input variables
            for arg in args_with_derivatives:
                arg_name = arg.name

                def repl(m: Any) -> str:
                    return f"{m.group(1)}{arg_name}_p{m.group(2)}"

                fw_formula = re.sub(IDENT_REGEX.format(arg_name), repl, fw_formula)

            # Do the final conjugate 3)
            fw_formula = f"({fw_formula}).conj()"

            # Since there is a single differentiable input and we necessarily need its tangent, we can
            # simply require all differentiable inputs' tangents.
            required_inputs_tangent = tuple(all_arg_names)
            formula = fw_formula
        elif formula == "auto_linear":
            if (
                len(forward_derivatives) > 1
                or len(forward_derivatives[0].var_names) > 1
            ):
                raise RuntimeError(
                    f"Derivative definition of {defn_name} in derivatives.yaml defines the "
                    "forward definition of gradient as linear but this only works "
                    "for functions with a single differentiable output."
                )
            # This transformation is based on the observation that linear functions can be written as:
            #   y = f(x) = A * x
            # for some matrix A, and the Jacobian of the function f is also A.
            # So doing J * v = A * v = f(v).
            # Hence to do the jvp, we simply need to evaluate the function at the point v instead of x.
            # We do this by calling the forward again, replacing any occurrence of the differentiable
            # input "foo" by its tangent "foo_t".
            # Note that multiple inputs are not a problem as long as the function is truly linear w.r.t.
            # the vector where all the differentiable inputs are stacked.
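            # For example, for an op "foo(Tensor x) -> Tensor" declared auto_linear,
            # the generated forward formula is roughly "at::foo(x_t)" (or "x_t.foo(...)"
            # for method-only variants), as built below.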

            diff_arg_names = [arg.name for arg in args_with_derivatives]
            assert len(diff_arg_names) > 0

            # Do replacement of input variables
            new_args = []
            for arg_name in all_arg_names:
                if arg_name in diff_arg_names:
                    arg_name = arg_name + "_t"
                new_args.append(arg_name)

            # TODO we are trolling
            if f.func.has_symint():
                defn_name += "_symint"

            # Call into the forward again. We need two cases here to handle both Tensor methods and at:: functions.
            if Variant.function in f.variants:
                fw_formula = f"at::{defn_name}({', '.join(new_args)})"
            else:
                assert Variant.method in f.variants
                fw_formula = f"{new_args[0]}.{defn_name}({', '.join(new_args[1:])})"

            # All of the input tangents are always used so all of them are required here.
            required_inputs_tangent = tuple(diff_arg_names)
            formula = fw_formula

        # At this point, the formula is final and is not modified anymore.

        # In the forward formula we use the primals instead of the input Tensors.
        # This call inspects the formula to find which inputs' primals are used.
        required_inputs_primal = find_required_inputs(formula, "_p")

        updated_derivatives.append(
            ForwardDerivative(
                formula=formula,
                var_names=defn.var_names,
                var_types=defn.var_types,
                required_inputs_fw_grad=required_inputs_tangent,
                required_inputs_primal=required_inputs_primal,
                required_original_self_value=False,
                is_reusing_outplace_formula=False,
            )
        )

    return updated_derivatives


def is_forward_derivative_definition(
    all_arg_names: list[str], names: tuple[str, ...]
) -> bool:
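    # Only the first name needs to be inspected: the names of a single definition
    # are expected to refer either all to inputs (a backward derivative) or all to
    # outputs (a forward derivative).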
    for name in names:
        return name not in all_arg_names
    raise RuntimeError("Expected `names` to be non-empty")


def create_differentiability_info(
    defn_dict: dict[Any, Any],
    functions_by_signature: dict[FunctionSchema, list[NativeFunction]],
    functions_by_schema: dict[str, NativeFunction],
    op_counter: Counter[str],
    used_dispatch_keys: set[str],
) -> tuple[FunctionSchema, dict[str, DifferentiabilityInfo]]:
    """Processes a single entry `defn` in derivatives.yaml"""

    def canonical_function(
        functions: Sequence[NativeFunction], name: str
    ) -> NativeFunction:
        for f in functions:
            if (
                not f.func.is_functional_fn()
                and not f.func.is_out_fn()
                and name == str(f.func.name.name)
            ):
                return f
        # some functions only have in-place variants
        assert name + "_" == cpp.name(functions[0].func)
        return functions[0]

    def split_names(raw_names: str) -> tuple[str, ...]:
        """Given "foo, bar", return ("foo", "bar")."""
        return tuple(x.strip() for x in raw_names.split(","))

    def check_grad_usage(defn_name: str, derivatives: Sequence[Derivative]) -> None:
        """
        Check for some subtle mistakes one might make when writing derivatives.
        These mistakes will compile, but will be latent until a function is
        used with double backwards.
        """

        uses_grad = False  # true if any derivative uses "grad"
        num_grads_uses = 0  # count of uses of "grads" or "grads[INDEX]"
        uses_named_grads = False  # true if any derivative uses "grad_{name}"
        used_grads_indices: list[int] = []  # which indices of grads are used
        for d in derivatives:
            formula = d.formula
            uses_grad = uses_grad or bool(
                re.findall(IDENT_REGEX.format("grad"), formula)
            )
            num_grads_uses += len(re.findall(IDENT_REGEX.format("grads"), formula))
            uses_named_grads = uses_named_grads or bool(d.named_gradients)
            used_grads_indices.extend(used_gradient_indices(formula))
        # This is a basic sanity check: the number of places we see
        # "grads" should be no fewer than the number of indices we see
        # inside "grads". They may not be equal because we may use
        # "grads" without an index.
        assert num_grads_uses >= len(used_grads_indices)
        # Thus if the number is equal, every use of grads is also
        # indexed.
        only_used_grads_indices = num_grads_uses == len(used_grads_indices)

        if uses_grad and num_grads_uses > 0:
            raise RuntimeError(
                f"Derivative definition of {defn_name} in derivatives.yaml illegally "
                "mixes use of 'grad' and 'grads'. Consider replacing "
                "occurrences of 'grad' with 'grads[0]'"
            )

        if only_used_grads_indices and set(used_grads_indices) == {0}:
            raise RuntimeError(
                f"Derivative definition of {defn_name} in derivatives.yaml solely "
                "refers to 'grads[0]'.  If the first output is indeed the "
                "only differentiable output, replace 'grads[0]' with 'grad'; "
                "otherwise, there is a likely error in your derivatives "
                "declaration."
            )

        if uses_named_grads and (uses_grad or num_grads_uses > 0):
            raise RuntimeError(
                f"Derivative definition of {defn_name} in derivatives.yaml illegally "
                'mixes use of "grad_RETURN_NAME" and "grad" or "grads[x]". Use '
                "only one method for identifying gradients."
            )

    @with_native_function
    def set_up_derivatives(
        f: NativeFunction,
    ) -> tuple[
        Sequence[Derivative],
        Sequence[ForwardDerivative],
        Sequence[Binding],
        Sequence[str],
        Sequence[str],
    ]:
        # Set up the derivative information
        derivatives: list[Derivative] = []
        forward_derivatives: list[ForwardDerivative] = []
        non_differentiable_arg_names: list[str] = []
        args_with_derivatives_set: set[str] = set()

        all_arg_names = [a.name for a in cpp_arguments(f)]
        all_ret_names = [
            r.name for r in f.func.returns
        ]  # only used for the assert below
        # output_differentiability is captured from the enclosing
        # scope. Don't modify it.
        #
        # If it is not present, then no output is explicitly
        # undifferentiable.
        #
        # It may be present and shorter than the length of return
        # values. If that's the case, any return value that does not
        # have a corresponding entry is considered not differentiable.
        differentiability = output_differentiability or [True] * len(f.func.returns)
        # A return is available as a named gradient ...
        available_named_gradients = [
            f"grad_{ret.name}"
            for ret, differentiable in zip(f.func.returns, differentiability)
            # if it has not been explicitly made undifferentiable
            if differentiable
            # and if it has a name
            and ret.name is not None
            # and if its type is differentiable
            and ret.type.is_tensor_like()
        ]
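        # e.g. for an op declared as "-> (Tensor output, Tensor buffer)", this
        # yields ["grad_output", "grad_buffer"].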

        for raw_names in sorted(defn.keys()):
            formula = defn[raw_names]
            names = split_names(raw_names)

            for name in names:
                assert not (name in all_arg_names and name in all_ret_names), (
                    f"While processing the derivative formula for '{f.func.name}' wrt '{name}', "
                    f"expected '{name}' to not be both an input arg and named return. "
                )

            if is_forward_derivative_definition(all_arg_names, names):
                forward_derivatives.append(create_forward_derivative(f, formula, names))
            else:
                if formula.lower().strip() == "non_differentiable":
                    non_differentiable_arg_names += names
                else:
                    derivative = create_derivative(
                        f, formula, names, available_named_gradients
                    )
                    derivatives.append(derivative)
                    args_with_derivatives_set |= set(names)

        overlap = args_with_derivatives_set.intersection(non_differentiable_arg_names)
        if overlap:
            raise RuntimeError(
                f"derivatives definition for {defn} has overlapping non_differentiable "
                f"and differentiable variables: {overlap}"
            )

        # Next, let us determine the list of inputs in order.
        # TODO: do we need to eagerly calculate and save it here? Can it be derived
        # from NativeFunction and `derivatives` on callsites instead?
        args_with_derivatives = [
            a for a in cpp_arguments(f) if a.name in args_with_derivatives_set
        ]

        # Postprocess forward derivatives definitions now that we know the differentiable arguments
        forward_derivatives = postprocess_forward_derivatives(
            f,
            defn_name,
            all_arg_names,
            derivatives,
            forward_derivatives,
            args_with_derivatives,
        )

        # Test to see if the use of 'grads' makes sense.
        check_grad_usage(defn_name, derivatives)

        return (
            derivatives,
            forward_derivatives,
            args_with_derivatives,
            non_differentiable_arg_names,
            available_named_gradients,
        )

    # NB: Removes 'name' from defn dictionary
    specification = defn_dict.pop("name")
    defn_name, _ = split_name_params(specification)
    # NB: Removes 'output_differentiability' from defn dictionary
    #     `None` means all differentiable.
    output_differentiability = defn_dict.pop("output_differentiability", None)
    output_differentiability_conditions = None
    if output_differentiability and any(
        isinstance(diff, str) for diff in output_differentiability
    ):
        if len(output_differentiability) != 1:
            raise RuntimeError(
                f"Not supported: for {specification}, "
                f"output_differentiability must either be "
                f"List[bool] or a List[str] where each str is a "
                f"condition. In the case where it is a condition, "
                f"we only support single-output functions. "
                f"Please file us an issue. "
            )
        output_differentiability_conditions = output_differentiability
        output_differentiability = [True]

    schema_function = functions_by_schema.get(specification)
    if not schema_function:
        avail = "\n".join(
            k for k, v in functions_by_schema.items() if cpp.name(v.func) == defn_name
        )
        raise RuntimeError(
            f"could not find ATen function for schema: {specification}. "
            f"Available signatures:\n{avail}"
        )

    # now map this to the legacy schema; this isn't technically necessary, but we'd need some logic here
    # to map in-place schemas to the out-of-place variants.
    # TODO: maybe the logic to handle the legacy schema is no longer necessary?
    signature = schema_function.func.signature()
    functions = functions_by_signature[signature]
    if len(functions) == 0:
        avail = "\n".join(
            str(k)
            for k, v in functions_by_signature.items()
            if cpp.name(k) == defn_name
        )
        raise RuntimeError(
            f"could not find ATen function for legacy signature: {signature} "
            f"corresponding to schema {specification}.  Please report a bug to PyTorch. "
            f"Available signatures:\n{avail}"
        )

    canonical = canonical_function(functions, defn_name)
    if "grad_input_mask" in (a.name for a in cpp_arguments(canonical)):
        raise RuntimeError(
            f"Schema for {defn_name} has an argument named grad_input_mask, "
            "but this name would be shadowed by our codegen. "
            "Please use a different name in native_functions.yaml."
        )

    if "result" in (a.name for a in cpp_arguments(canonical)):
        raise RuntimeError(
            f"Schema for {defn_name} has an argument named result, "
            "but this is only allowed for outputs. "
            "Please use a different name in native_functions.yaml."
        )

    diffinfo_dict = {}
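    # One DifferentiabilityInfo is produced per dispatch key listed under "dispatch:"
    # in the yaml entry, e.g. "Default" plus any autograd key such as "AutogradNestedTensor".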
    for key, defn in defn_dict["dispatch"].items():
        if key != "Default" and key not in _VALID_AUTOGRAD_KEYS:
            raise RuntimeError(
                f"Invalid dispatch key {key} in derivatives.yaml for {specification},"
                f" expected key to be one of {_VALID_AUTOGRAD_KEYS}"
            )
        if key not in used_dispatch_keys:
            used_dispatch_keys.add(key)

        (
            derivatives,
            forward_derivatives,
            args_with_derivatives,
            non_differentiable_arg_names,
            available_named_gradients,
        ) = set_up_derivatives(canonical)

        used_named_gradients: set[str] = set()
        for d in derivatives:
            used_named_gradients |= d.named_gradients

        # only assign an op name if we are actually going to calculate a derivative
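        # e.g. "add" yields op names "AddBackward0", "AddBackward1", ...: the numeric
        # suffix from op_counter disambiguates overloads, and non-Default dispatch keys
        # are folded into the prefix (e.g. "AddBackwardAutogradNestedTensor0").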
        op = None
        if args_with_derivatives:
            op_prefix = _create_op_prefix(defn_name)
            if key != "Default":
                op_prefix = op_prefix + key
            op = f"{op_prefix}{op_counter[op_prefix]}"
            op_counter[op_prefix] += 1

        diffinfo_dict[key] = DifferentiabilityInfo(
            name=defn_name,
            func=canonical,
            op=op,
            derivatives=derivatives,
            forward_derivatives=forward_derivatives,
            all_saved_inputs=dedup_vars(
                [v for d in derivatives for v in d.saved_inputs]
            ),
            all_saved_outputs=dedup_vars(
                [v for d in derivatives for v in d.saved_outputs]
            ),
            available_named_gradients=available_named_gradients,
            used_named_gradients=used_named_gradients,
            args_with_derivatives=args_with_derivatives,
            non_differentiable_arg_names=non_differentiable_arg_names,
            output_differentiability=output_differentiability,
            output_differentiability_conditions=output_differentiability_conditions,
        )

    return canonical.func, diffinfo_dict


GRAD_INDEX_REGEX = r"(?:^|\W)grads\[(\d+)\]"


def used_gradient_indices(formula: str) -> list[int]:
    """Determine a list of gradient indices (the i in grads[i]) that
    are used by the formula.

    >>> used_gradient_indices("foo(grads[0], grads[1])")
    [0, 1]
    """
    return [int(i) for i in re.findall(GRAD_INDEX_REGEX, formula)]


def saved_variables(
    formula: str,
    nctypes: list[NamedCType],
    var_names: tuple[str, ...],
) -> tuple[str, tuple[SavedAttribute, ...]]:
    def stride_expr(name: str) -> str:
        assert var_names == (name,), (
            'Replacement for ".strides()" is currently only supported for single derivatives of the same tensor '
            'that ".strides()" is being called on.'
        )
        return f'strides_or_error({name}, "{name}")'

    REPLACEMENTS: list[tuple[str, dict[str, Any]]] = [
        # replace self.sym_sizes() with self_sym_sizes
        (
            r"{}.sym_sizes\(\)",
            {
                "suffix": "_sym_sizes",
                "nctype": lambda name: NamedCType(name, BaseCType(symIntArrayRefT)),
            },
        ),
        # replace self->sym_sizes() with self_sym_sizes_opt
        (
            r"{}->sym_sizes\(\)",
            {
                "suffix": "_sym_sizes_opt",
                "nctype": lambda name: NamedCType(
                    name, OptionalCType(BaseCType(symIntArrayRefT))
                ),
                "expr": lambda name: f"{name}.has_value() ? std::optional<c10::SymIntArrayRef>({name}->sym_sizes()) : std::nullopt",
            },
        ),
        # replace self.sym_blocksize() with self_sym_blocksize_opt
        (
            r"{}.sym_blocksize\(\)",
            {
                "suffix": "_self_sym_blocksize_opt",
                "nctype": lambda name: NamedCType(
                    name, OptionalCType(BaseCType(symIntArrayRefT))
                ),
                "expr": lambda name: f"at::sparse_csr::getSymIntBlockSize({name})",
            },
        ),
        # replace self.options() with self_options
        (
            r"{}.options\(\)",
            {
                "suffix": "_options",
                "nctype": lambda name: NamedCType(name, BaseCType(tensorOptionsT)),
            },
        ),
        # replace zeros_like(self) with self_info
        (
            r"zeros_like\({}\)",
            {
                "suffix": "_info",
                "nctype": lambda name: NamedCType(name, BaseCType(typeAndSizeT)),
                "expr": lambda name: name,  # at save-time
                "res": lambda name: name + "_info.zeros()",  # at eval-time
            },
        ),
        # replace self.sym_size(2) with self_sym_size_2
        (
            r"{}.sym_size\((-?\w+)\)",
            {
                "suffix": lambda m: f"_sym_argsize_{m.groups()[0].replace('-', 'minus_')}",
                "nctype": lambda name: NamedCType(name, BaseCType(SymIntT)),
            },
        ),
        # replace self.numel() with self_numel
        (
            r"{}.numel\(\)",
            {
                "suffix": "_numel",
                "nctype": lambda name: NamedCType(name, BaseCType(longT)),
            },
        ),
        # replace self.sym_numel() with self_sym_numel
        (
            r"{}.sym_numel\(\)",
            {
                "suffix": "_sym_numel",
                "nctype": lambda name: NamedCType(name, BaseCType(SymIntT)),
            },
        ),
        # replace to_args_sizes(self) with self_args_sizes
        (
            r"to_args_sizes\({}\)",
            {
                "suffix": "_args_sizes",
                "nctype": lambda name: NamedCType(
                    name, VectorCType(VectorCType(BaseCType(longT)))
                ),
            },
        ),
        # replace to_args_sizes_symint(self) with self_args_sizes_symint
        (
            r"to_args_sizes_symint\({}\)",
            {
                "suffix": "_args_sizes_symint",
                "nctype": lambda name: NamedCType(
                    name, VectorCType(VectorCType(BaseCType(SymIntT)))
                ),
            },
        ),
        # replace to_args_scalartypes(self) with self_args_scalartypes
        (
            r"to_args_scalartypes\({}\)",
            {
                "suffix": "_args_scalartypes",
                "nctype": lambda name: NamedCType(
                    name, VectorCType(BaseCType(scalarTypeT))
                ),
            },
        ),
        # replace TensorGeometry(self) with self_geometry
        (
            r"TensorGeometry\({}\)",
            {
                "suffix": "_geometry",
                "nctype": lambda name: NamedCType(name, BaseCType(tensorGeometryT)),
            },
        ),
        # replace self.scalar_type() with self_scalar_type
        (
            r"{}.scalar_type\(\)",
            {
                "suffix": "_scalar_type",
                "nctype": lambda name: NamedCType(name, BaseCType(scalarTypeT)),
            },
        ),
        # replace self.dim() with self_dim
        (
            r"{}.dim\(\)",
            {
                "suffix": "_dim",
                "nctype": lambda name: NamedCType(name, BaseCType(longT)),
            },
        ),
        # replace self.sym_strides() with self_sym_strides
        (
            r"{}.sym_strides\(\)",
            {
                "suffix": "_sym_strides",
                "nctype": lambda name: NamedCType(name, BaseCType(symIntArrayRefT)),
                "expr": stride_expr,
            },
        ),
        # replace self.layout() with self_layout
        (
            r"{}.layout\(\)",
            {
                "suffix": "_layout",
                "nctype": lambda name: NamedCType(name, BaseCType(layoutT)),
            },
        ),
        # replace self.is_conj() with self_conjugate
        (
            r"{}.is_conj\(\)",
            {
                "suffix": "_conjugate",
                "nctype": lambda name: NamedCType(name, BaseCType(boolT)),
            },
        ),
    ]

    # find which arguments need to be saved
    saved: list[SavedAttribute] = []

    if ".sizes()" in formula or "->sizes()" in formula:
        raise RuntimeError(
            ".sizes() is not supported in derivative formulas. Instead, please use the SymInt version, "
            + f".sym_sizes(), which returns a c10::SymIntArrayRef. formula={formula}"
        )
    if re.search(r"\.size\([-]?\d+\)", formula) or re.search(
        r"->size\([-]?\d+\)", formula
    ):
        raise RuntimeError(
            ".size(int) is not supported in derivative formulas. Instead, please use the SymInt version, "
            + f".sym_size(int), which returns a c10::SymInt. formula={formula}"
        )
    if ".strides()" in formula or "->strides()" in formula:
        raise RuntimeError(
            ".strides() is not supported in derivative formulas. Instead, please use the SymInt version, "
            + f".sym_strides(), which returns a c10::SymIntArrayRef. formula={formula}"
        )
    for nctype in nctypes:
        name = (
            nctype.name.name if isinstance(nctype.name, SpecialArgName) else nctype.name
        )
        # First search the formula for expressions which can be evaluated
        # when the autograd Function is created to avoid saving variables
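        # e.g. "self.numel()" appearing in a formula is rewritten to "self_numel"
        # and only that int64_t value is saved, rather than the whole "self" tensor.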
        for regex, info in REPLACEMENTS:

            def repl(m: re.Match[str]) -> str:
                suffix: str = (
                    info["suffix"](m) if callable(info["suffix"]) else info["suffix"]
                )
                expr: str = info["expr"](name) if "expr" in info else m.group(0)
                saved.append(
                    SavedAttribute(
                        nctype=info["nctype"](name + suffix),
                        expr=expr,
                    )
                )
                if "res" in info:
                    replacement: str = info["res"](name)
                    return replacement
                return name + suffix

            formula = re.sub(regex.format(name), repl, formula)

        # std::optional<std::string> types stored in Backward nodes must be
        # converted to std::optional<std::string_view> before being passed into
        # the backward function
        if nctype.type == OptionalCType(BaseCType(stringT)):
            formula = re.sub(
                rf"\b{name}\b",
                f"{name}.has_value() ? std::optional<c10::string_view>({name}.value()) : std::nullopt",
                formula,
            )

        # Find any variables which remain in the formula and save them
        if re.search(IDENT_REGEX.format(name), formula):
            saved.append(
                SavedAttribute(
                    nctype=nctype,
                    expr=name,
                )
            )

    return formula, tuple(saved)


def _create_op_prefix(name: str) -> str:
    """Takes a native function name and converts it to an op prefix name.

    Note that the "name" parameter must be the native function name
    without the optional variant suffix, so "add" instead of
    "add.out".

    OP names correspond to classes, hence the change to title case.

    Example::
    >>> _create_op_prefix('add')
    'AddBackward'
    """
    camel_case = "".join([p.title() for p in name.split("_")])
    return (camel_case + "Backward").replace("ForwardBackward", "Backward")


def dedup_vars(vars: Sequence[SavedAttribute]) -> Sequence[SavedAttribute]:
    seen: set[str] = set()
    saved: list[SavedAttribute] = []
    for var in vars:
        name = (
            var.nctype.name.name
            if isinstance(var.nctype.name, SpecialArgName)
            else var.nctype.name
        )
        if name in seen:
            continue
        seen.add(name)
        saved.append(var)
    return saved