# Generates ADInplaceOrViewType.h/cpp
#
# NOTE: If any changes are being made to the ADInplaceOrView codegen, please also check
# if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp.
# The fallback is expected to mimic this codegen, so we should keep the two in sync.

from __future__ import annotations

from torchgen.api import cpp
from torchgen.api.autograd import (
    dispatch_strategy,
    gen_differentiable_outputs,
    NativeFunctionWithDifferentiabilityInfo,
)
from torchgen.api.types import (
    BaseCType,
    Binding,
    boolT,
    ConstRefCType,
    CType,
    DispatcherSignature,
    intArrayRefT,
    longT,
    OptionalCType,
    symIntArrayRefT,
    SymIntT,
    tensorT,
)
from torchgen.code_template import CodeTemplate
from torchgen.context import with_native_function
from torchgen.model import (
    NativeFunction,
    SchemaKind,
    SelfArgument,
    TensorOptionsArguments,
    Type,
)
from torchgen.utils import FileManager

from .context import with_native_function_with_differentiability_info
from .gen_trace_type import (
    get_return_value,
    MANUAL_AUTOGRAD,
    tie_return_values,
    type_wrapper_name,
)


# See NOTE [ Autograd View Variables ] in variable.h for details.
# If you update the VIEW_FUNCTIONS or RETURNS_VIEWS_OF_INPUT lists,
# you **MUST** also update the public list of view ops accordingly in
# docs/source/tensor_view.rst. Note that not all ATen functions are exposed publicly,
# e.g. alias & sparse_coo_tensor_with_dims_and_tensors.
#
# A map: function name => name of the argument that all outputs are views of

VIEW_FUNCTIONS_WITH_METADATA_CHANGE = [
    "view_as_complex",
    "view_as_real",
    "_conj",
    "_neg_view",
    "_nested_get_values",
    "_nested_view_from_buffer",
    "_nested_view_from_jagged",
]

VIEW_FUNCTIONS = {
    "numpy_T": "self",
    "alias": "self",
    "as_strided": "self",
    "diagonal": "self",
    "expand": "self",
    "permute": "self",
    "select": "self",
    "slice": "self",
    "slice_inverse": "self",
    "split": "self",
    "split_with_sizes": "self",
    "squeeze": "self",
    "t": "self",
    "transpose": "self",
    "unfold": "self",
    "unsqueeze": "self",
    "flatten": "self",
    "view": "self",
    "unbind": "self",
    "_indices": "self",
    "_values": "self",
    "indices": "self",
    "values": "self",
    "crow_indices": "self",
    "col_indices": "self",
    "ccol_indices": "self",
    "row_indices": "self",
    # The sparse_coo ctor output should really be views of both indices and values,
    # but we only support making a view of a single variable, and indices is
    # discrete anyway.
    # FIXME: clone indices on construction.
    "sparse_coo_tensor_with_dims_and_tensors": "values",
    "_reshape_alias": "self",
    "_test_autograd_multiple_dispatch_view": "self",
}

for key in VIEW_FUNCTIONS_WITH_METADATA_CHANGE:
    VIEW_FUNCTIONS[key] = "self"

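# Illustrative note (hedged; not consumed by the codegen itself): after the loop
# above, metadata-changing views are ordinary entries of the map, e.g.
#   VIEW_FUNCTIONS["view_as_real"] == "self"
# and get_view_info() below resolves an op's base name against this map (falling
# back to RETURNS_VIEWS_OF_INPUT).
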
# Note: some entries below are just compositions of the view functions above.
# This set contains both the root view functions and any functions that are purely
# composed of viewing functions; it is used by the JIT to determine when an operator
# may return a view of its inputs (although such operators may sometimes return a
# copy instead, e.g. `contiguous`).
RETURNS_VIEWS_OF_INPUT = set(VIEW_FUNCTIONS.keys()).union(
    {
        "chunk",
        "detach",
        "contiguous",
        "reshape",
        "reshape_as",
        "expand_as",
        "view_as",
        "real",
        "imag",
        "narrow",
        "movedim",
        "tensor_split",
        "swapdims",
        "swapaxes",
        "mT",
        "mH",
        "adjoint",
        "matrix_H",
    }
)

# These are the functions we consider views for the purposes of validating
# StorageImpl and TensorImpl in gen_variable_type.
# `_unsafe_view` is not included in VIEW_FUNCTIONS above because it is not a
# view for the purposes of the ADInplaceOrView kernel, so we do not want to call
# as_view on it. See NOTE [Unsafe View] for more info.
ALL_VIEW_FUNCTIONS = {
    **VIEW_FUNCTIONS,
    "_unsafe_view": "self",
}

ARRAYREF_TO_VEC = CodeTemplate(
    """\
auto ${vec} = ${arg}.vec();
"""
)

OPTIONAL_TO_VAL = CodeTemplate(
    """\
auto ${val} = ${arg}.value_or(${default});
"""
)

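# Illustrative example of the two templates above (a hedged sketch): for a view op
# taking an IntArrayRef argument `size` and an int? argument `dim`, the generated
# setup code would contain roughly
#   auto size_vec = size.vec();
#   auto dim_val = dim.value_or(0);
# so that the replay ViewFunc/lambda closes over owned values instead of reference
# or optional types.
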
CALL_DISPATCH = CodeTemplate(
    """\
at::_ops::${unambiguous_name}::call(${unpacked_args})"""
)

REVERSE_VIEW_DISPATCH = CodeTemplate(
    """\
${reverse_name}(${unpacked_args})"""
)

MULTI_OUTPUT_VIEW_ITERATION = CodeTemplate(
    """\
for (auto ${view_idx} : c10::irange(${var}.size())) {
  ${body}
}
"""
)

SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE = CodeTemplate(
    """\
std::unique_ptr<torch::autograd::ViewFunc> func(nullptr);
std::function<at::Tensor(const at::Tensor&)> rev_func=nullptr;
if (${is_view_with_metadata_change} ||
    !self.unsafeGetTensorImpl()->support_as_strided() ||
    self.unsafeGetTensorImpl()->is_python_dispatch() ||
    c10::AutogradState::get_tls_state().get_view_replay_enabled()) {
  ${replay_view_func}
  ${reverse_replay_view_func}
}
"""
)

REPLAY_VIEW_FUNC = CodeTemplate(
    """\
func = std::make_unique<${view_func_name}>(${view_func_args});
"""
)

REVERSE_REPLAY_VIEW_LAMBDA_FUNC = CodeTemplate(
    """\
rev_func = [=](const at::Tensor& ${input_view}) {
  return ${reverse_replay_view_call};
};
"""
)

METHOD_DEFINITION = CodeTemplate(
    """\
${return_type} ${type_wrapper_name}(${formals}) {
  ${type_definition_body}
}
"""
)

WRAPPER_REGISTRATION = CodeTemplate(
    """\
m.impl("${unqual_operator_name_with_overload}",
       TORCH_FN(${class_type}::${type_wrapper_name})
);
"""
)

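# Illustrative example (hedged; exact wrapper names come from type_wrapper_name() in
# gen_trace_type.py): for a view overload such as `transpose.int`,
# WRAPPER_REGISTRATION expands to roughly
#   m.impl("transpose.int",
#          TORCH_FN(ADInplaceOrView::transpose_int)
#   );
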
AUTOGRAD_NOT_IMPLEMENTED_REGISTRATION = CodeTemplate(
    """\
m.impl("${unqual_operator_name_with_overload}", torch::autograd::autogradNotImplementedFallback());
"""
)

INPLACE_REDISPATCH = CodeTemplate(
    """\
{
  at::AutoDispatchBelowADInplaceOrView guard;
  at::_ops::${unambiguous_name}::redispatch(${unpacked_args});
}
"""
)

ASSIGN_RETURN_VALUE = CodeTemplate(
    """\
${return_values} = ${rhs_value};
"""
)

VIEW_REDISPATCH = CodeTemplate(
    """\
${assign_return_values} ([&]() {
  at::AutoDispatchBelowADInplaceOrView guard;
  return at::_ops::${unambiguous_name}::redispatch(${unpacked_args});
})();
"""
)

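# Illustrative example (hedged): for a view op such as `expand`, VIEW_REDISPATCH
# expands to roughly
#   auto _tmp = ([&]() {
#     at::AutoDispatchBelowADInplaceOrView guard;
#     return at::_ops::expand::redispatch(ks & c10::after_ADInplaceOrView_keyset, self, size, implicit);
#   })();
# after which the view-tracking code generated by emit_view_body() wraps `_tmp`
# via as_view().
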
TMP_VAR = "_tmp"


# FIXME: Ideally these functions should be methods on the Type class, but we have a
#        comment in codegen/model.py saying these concepts are not well defined.
#        Thus we put a version that is commonly used by the autograd codegen here.
def is_tensor_type(t: Type) -> bool:
    # TODO: Should handle optional here?
    return t.is_tensor_like() and t.is_list_like() is None


def is_tensor_list_type(t: Type) -> bool:
    # TODO: Should handle optional here?
    return t.is_tensor_like() and t.is_list_like() is not None

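# Illustrative examples (hedged; uses torchgen.model types, not executed here):
#   from torchgen.model import BaseTy, BaseType, ListType
#   is_tensor_type(BaseType(BaseTy.Tensor))                       # True
#   is_tensor_list_type(ListType(BaseType(BaseTy.Tensor), None))  # True
#   is_tensor_type(BaseType(BaseTy.int))                          # False
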

UNPACK_TENSOR = CodeTemplate(
    """\
auto${ref} ${arg_name}_ = unpack${suffix}(${arg_name}, "${arg_name}", ${arg_pos});"""
)


def unpacked_name(arg_name: str) -> str:
    return arg_name + "_"


# e.g. select.int -> select_copy_int_inverse()
def inverse_view_name(f: NativeFunction) -> str:
    copy_variant = f"{f.root_name}_copy"
    overload = f"{f.func.name.overload_name}"
    if overload != "":
        overload = "_" + overload
    return f"{copy_variant}{overload}_inverse"


def extract_bindings(f: NativeFunction) -> list[Binding]:
    return [
        r
        for a in f.func.schema_order_arguments()
        for r in cpp.argument(
            a,
            method=False,
            symint=True,
            cpp_no_default_args=set(),
            faithful=False,
            has_tensor_options=False,
        )
    ]


@with_native_function
def unpack_args(f: NativeFunction) -> tuple[list[str], list[Binding]]:
    body: list[str] = []
    unpacked_bindings: list[Binding] = []

    for i, binding in enumerate(extract_bindings(f)):
        assert not isinstance(binding.argument, SelfArgument)
        if isinstance(binding.argument, TensorOptionsArguments):
            raise RuntimeError("VariableKernel shouldn't take TensorOptions")

        is_nullable = binding.argument.type.is_nullable()
        if not binding.argument.type.is_tensor_like() or is_nullable:
            unpacked_bindings.append(binding)
            continue

        is_tensor_list = is_tensor_list_type(binding.argument.type)
        ref = (not is_nullable) and not is_tensor_list
        suffix = "_opt" if is_nullable and not is_tensor_list else ""
        body.append(
            UNPACK_TENSOR.substitute(
                arg_name=binding.name,
                arg_pos=i,
                suffix=suffix,
                ref="&" if ref else "",
            )
        )
        unpacked_bindings.append(
            Binding(
                name=unpacked_name(binding.name),
                nctype=binding.nctype,
                argument=binding.argument,
                default=binding.default,
            )
        )

    return body, unpacked_bindings

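# Illustrative example (hedged): for a schema such as
# `as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None)`,
# unpack_args() emits roughly
#   auto& self_ = unpack(self, "self", 0);
# and passes the non-tensor (and nullable) arguments through to the unpacked
# bindings unchanged.
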

def get_base_name(f: NativeFunction) -> str:
    return f.func.name.name.base  # TODO: should be str(f.func.name.name)?


def get_view_info(f: NativeFunction) -> str | None:
    base_name = get_base_name(f)
    view_info = VIEW_FUNCTIONS.get(base_name, None)
    if view_info is None and base_name in RETURNS_VIEWS_OF_INPUT:
        view_info = "self"
    return view_info

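# Illustrative note: get_view_info() returns "self" for view ops such as `expand`
# or `reshape` (the latter via RETURNS_VIEWS_OF_INPUT), "values" for
# `sparse_coo_tensor_with_dims_and_tensors`, and None for non-view ops such as `add`.
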

def emit_view_func(
    f: NativeFunction, bindings: list[Binding], view_idx: str | None = None
) -> str:
    """Generate an additional lambda function to recover views in backward when as_strided is not supported.
    See Note [View + Inplace update for base tensor] and [View + Inplace update for view tensor] for more details.
    """
    # TODO: Clean this logic up if we get rid of reverse view funcs or reify them.
    input_base = "input_base"
    replay_view_func = ""
    updated_args: list[str] = []
    known_view_arg_simple_types: list[CType] = [
        BaseCType(longT),
        OptionalCType(BaseCType(longT)),
        BaseCType(SymIntT),
        OptionalCType(BaseCType(SymIntT)),
        BaseCType(boolT),
        BaseCType(intArrayRefT),
        BaseCType(symIntArrayRefT),
        ConstRefCType(BaseCType(tensorT)),
        ConstRefCType(OptionalCType(BaseCType(tensorT))),
    ]
    for binding in bindings:
        arg, arg_type = binding.name, binding.nctype.type
        if arg == "self":
            updated_args.append(input_base)
            continue
        if arg_type not in known_view_arg_simple_types:
            known_types_str = ", ".join([str(t) for t in known_view_arg_simple_types])
            raise TypeError(
                f"You are adding an {arg_type} {arg} argument to op {cpp.name(f.func)} in addition to known types: "
                f"{known_types_str}. Please update the list or materialize it so that it can be closed "
                "over by value; also add a test in pytorch/xla/test/test_operations.py where this code "
                "is exercised."
            )
        if arg_type == BaseCType(intArrayRefT) or arg_type == BaseCType(
            symIntArrayRefT
        ):
            # It's not safe to close over IntArrayRef by value, since this is a
            # reference type, so materialize a vector to close over by value
            arg_vec = arg + "_vec"
            replay_view_func += ARRAYREF_TO_VEC.substitute(arg=arg, vec=arg_vec)
            updated_args.append(arg_vec)
        elif arg_type == OptionalCType(BaseCType(longT)):
            # Materialize int64_t? to int64_t
            arg_value = arg + "_val"
            replay_view_func += OPTIONAL_TO_VAL.substitute(
                arg=arg, val=arg_value, default="0"
            )
            updated_args.append(arg_value)
        elif arg_type == ConstRefCType(BaseCType(tensorT)) or arg_type == ConstRefCType(
            OptionalCType(BaseCType(tensorT))
        ):
            # NB: Closing over a tensor. If a user modifies this tensor, this will be silently
            # incorrect. The proper thing to do is to store the version counter and copy on write.
            updated_args.append(arg)
        else:
            updated_args.append(arg)

    from .gen_view_funcs import view_func_name

    view_func_args = [b.name for b in bindings if b.name != "self"]
    if view_idx is not None:
        view_func_args.append(f"{view_idx}")
    replay_view_func += REPLAY_VIEW_FUNC.substitute(
        view_func_name=view_func_name(f, include_namespace=True),
        view_func_args=view_func_args,
    )

    input_view = "input_view"
    reverse_unpacked_args = [
        "self",
        f"{input_view}",
        # inverse_return_mode=
        "at::functionalization::InverseReturnMode::AlwaysView",
        *(() if view_idx is None else (f"{view_idx}",)),
        # skip input_base arg
        *updated_args[1:],
    ]

    from torchgen.api.functionalization import reverse_name

    reverse_replay_view_call = REVERSE_VIEW_DISPATCH.substitute(
        reverse_name=reverse_name(f, include_namespace=True),
        unpacked_args=reverse_unpacked_args,
    )
    reverse_replay_view_func = REVERSE_REPLAY_VIEW_LAMBDA_FUNC.substitute(
        input_view=input_view, reverse_replay_view_call=reverse_replay_view_call
    )

    is_view_with_metadata_change = (
        "true" if cpp.name(f.func) in VIEW_FUNCTIONS_WITH_METADATA_CHANGE else "false"
    )

    return SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE.substitute(
        is_view_with_metadata_change=is_view_with_metadata_change,
        replay_view_func=replay_view_func,
        reverse_replay_view_func=reverse_replay_view_func,
    )

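# Illustrative sketch (hedged; generated names elided) of the C++ block emitted by
# emit_view_func() via the templates above:
#   std::unique_ptr<torch::autograd::ViewFunc> func(nullptr);
#   std::function<at::Tensor(const at::Tensor&)> rev_func=nullptr;
#   if (<is_view_with_metadata_change> || !self.unsafeGetTensorImpl()->support_as_strided() || ...) {
#     func = std::make_unique<<generated ViewFunc>>(<non-self args, materialized>);
#     rev_func = [=](const at::Tensor& input_view) {
#       return <reverse view call>(self, input_view, at::functionalization::InverseReturnMode::AlwaysView, <non-self args>);
#     };
#   }
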

def emit_view_body(
    fn: NativeFunctionWithDifferentiabilityInfo, var: str
) -> tuple[str, str]:
    # See NOTE [ Autograd View Variables ] in variable.h for details.
    f = fn.func
    base_name = get_base_name(f)
    view_info = get_view_info(f)
    call = ""
    differentiable_outputs = gen_differentiable_outputs(fn)
    differentiable_output_vars = {r.name for r in differentiable_outputs}
    if not isinstance(view_info, str):
        raise TypeError(
            f"The view info should be a string for {base_name}, but it is: {view_info}"
        )
    if len(differentiable_output_vars) == 0:
        # no output is differentiable (.indices() for SparseTensors for example)
        rhs_value = (
            f"as_view({view_info}, {var}, "
            f"/* is_bw_differentiable */ false, /* is_fw_differentiable */ false)"
        )
    elif len(differentiable_output_vars) == 1:
        # Single differentiable output (Tensor or Tensor[])
        return_info = differentiable_outputs[0]
        # We only support a plain Tensor or a TensorList for functions that return views
        if not is_tensor_type(return_info.type) and not is_tensor_list_type(
            return_info.type
        ):
            raise RuntimeError(
                f"{base_name} that return differentiable views can only return Tensor or Tensor[]"
            )

        # See Note [ View + Inplace detection ]
        def get_creation_meta_in_mode(original: str) -> str:
            creation_meta_with_grad_mode = f"(at::GradMode::is_enabled() ? {original} : CreationMeta::NO_GRAD_MODE)"
            return f"InferenceMode::is_enabled() ? CreationMeta::INFERENCE_MODE : {creation_meta_with_grad_mode}"

        # Only allow rebasing of the history if we return a single Tensor
        # If we are in a no grad block, raise a warning
        # See NOTE [ View + Inplace detection ] for more details about this logic
        if is_tensor_list_type(return_info.type):
            creation_meta = get_creation_meta_in_mode("CreationMeta::MULTI_OUTPUT_NODE")
            view_idx = "view_idx"
            view_func = emit_view_func(
                f, extract_bindings(f), view_idx=view_idx
            ).strip()
            as_view_call = (
                f"as_view(/* base */ {view_info}, /* output */ {var}[{view_idx}], "
                "/* is_bw_differentiable */ true, /* is_fw_differentiable */ true, "
                "/* view_func */ std::move(func), /* rev_view_func */ rev_func, "
                f"/* creation_meta */ {creation_meta});"
            )
            call += MULTI_OUTPUT_VIEW_ITERATION.substitute(
                var=var, view_idx=view_idx, body=f"{view_func}\n{as_view_call}"
            )
            rhs_value = f"std::move({var})"
        else:
            call += emit_view_func(f, extract_bindings(f), view_idx=None)
            creation_meta = get_creation_meta_in_mode("CreationMeta::DEFAULT")
            rhs_value = (
                f"as_view(/* base */ {view_info}, /* output */ {var}, /* is_bw_differentiable */ true, "
                "/* is_fw_differentiable */ true, "
                f"/* view_func */ std::move(func), /* rev_view_func */ rev_func, /* creation_meta */ {creation_meta})"
            )
    else:
        # This could be supported but we don't need it at the moment, so keeping things simple.
        raise RuntimeError(
            "Functions that return multiple differentiable outputs "
            "where at least one of them is a view are not supported."
        )
    return call, rhs_value


def modifies_arguments(f: NativeFunction) -> bool:
    return f.func.kind() in [SchemaKind.inplace, SchemaKind.out]


@with_native_function_with_differentiability_info
def emit_inplace_or_view_body(fn: NativeFunctionWithDifferentiabilityInfo) -> list[str]:
    f = fn.func
    inplace_view_body: list[str] = []

    dispatcher_sig = DispatcherSignature.from_schema(f.func)
    dispatcher_exprs = dispatcher_sig.exprs()

    # code-generated ADInplaceOrView kernels plumb and recompute dispatch keys directly through the kernel for performance.
    # See Note [Plumbing Keys Through The Dispatcher] for details.
    dispatch_key_set = "ks & c10::after_ADInplaceOrView_keyset"
    redispatch_args = ", ".join([dispatch_key_set] + [a.expr for a in dispatcher_exprs])

    # Note that this calls the slow, dispatching variants of manual_cpp_binding ops.
    # We could probably work harder to ensure that the fast variants are called instead, but the perf benefit would be minimal.
    if modifies_arguments(f):  # inplace op
        inplace_view_body.append(
            INPLACE_REDISPATCH.substitute(
                unambiguous_name=f.func.name.unambiguous_name(),
                unpacked_args=redispatch_args,
            )
        )
        for r in cpp.return_names(f):
            inplace_view_body.append(f"increment_version({r});")
    else:
        assert get_view_info(f) is not None
        inplace_view_body.append(
            VIEW_REDISPATCH.substitute(
                assign_return_values="auto " + TMP_VAR + " = ",
                unambiguous_name=f.func.name.unambiguous_name(),
                unpacked_args=redispatch_args,
            )
        )
        call, rhs_value = emit_view_body(fn, TMP_VAR)
        inplace_view_body.append(call)
        assert rhs_value is not None
        inplace_view_body.append(
            ASSIGN_RETURN_VALUE.substitute(
                return_values=tie_return_values(f), rhs_value=rhs_value
            )
        )
    if f.func.returns:
        inplace_view_body.append(f"return {get_return_value(f)};")
    return inplace_view_body

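# Illustrative example (hedged): for an in-place op such as
# `transpose_(Tensor(a!) self, int dim0, int dim1)`, the generated kernel body is
# roughly
#   {
#     at::AutoDispatchBelowADInplaceOrView guard;
#     at::_ops::transpose_::redispatch(ks & c10::after_ADInplaceOrView_keyset, self, dim0, dim1);
#   }
#   increment_version(self);
#   return self;
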

@with_native_function
def gen_formals(f: NativeFunction) -> str:
    return ", ".join(
        # code-generated autograd kernels plumb and recompute dispatch keys directly through the kernel for performance.
        # See Note [Plumbing Keys Through The Dispatcher] for details.
        ["c10::DispatchKeySet ks"]
        + [
            f'{cpp.argument_type(a, binds="__placeholder__", symint=True).cpp_type()} {a.name}'
            for a in f.func.schema_order_arguments()
        ]
    )

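# Illustrative example (hedged): for the `transpose_` schema above, gen_formals()
# produces roughly
#   "c10::DispatchKeySet ks, at::Tensor & self, int64_t dim0, int64_t dim1"
# i.e. the dispatch key set is prepended to the schema-order C++ arguments.
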

@with_native_function_with_differentiability_info
def inplace_or_view_method_definition(
    fn: NativeFunctionWithDifferentiabilityInfo,
) -> str | None:
    f = fn.func
    if get_view_info(f) is None and (
        # For functions that modify their inputs but don't return them,
        # we can't give them autograd support.
        # See https://github.com/pytorch/pytorch/issues/53796
        not modifies_arguments(f)
        or len(f.func.returns) == 0
    ):
        return None
    return METHOD_DEFINITION.substitute(
        return_type=cpp.returns_type(f.func.returns, symint=True).cpp_type(),
        type_wrapper_name=type_wrapper_name(f),
        formals=gen_formals(f),
        type_definition_body=emit_inplace_or_view_body(fn),
    )


@with_native_function_with_differentiability_info
def inplace_or_view_method_registration(
    fn: NativeFunctionWithDifferentiabilityInfo,
) -> str | None:
    f = fn.func
    if get_view_info(f) is None and (
        not modifies_arguments(f) or len(f.func.returns) == 0
    ):
        return None
    return WRAPPER_REGISTRATION.substitute(
        unqual_operator_name_with_overload=f.func.name,
        type_wrapper_name=type_wrapper_name(f),
        class_type="ADInplaceOrView",
    )


def use_derived(fn: NativeFunctionWithDifferentiabilityInfo) -> bool:
    f = fn.func
    name = cpp.name(f.func)
    return name not in MANUAL_AUTOGRAD and dispatch_strategy(fn) == "use_derived"


def gen_inplace_or_view_type_env(
    fn: NativeFunctionWithDifferentiabilityInfo,
) -> dict[str, list[str]]:
    definition = inplace_or_view_method_definition(fn)
    registration = inplace_or_view_method_registration(fn)

    return {
        "ops_headers": (
            [f"#include <ATen/ops/{fn.func.root_name}_ops.h>"]
            if definition is not None
            else []
        ),
        "inplace_or_view_method_definitions": [definition]
        if definition is not None
        else [],
        "inplace_or_view_wrapper_registrations": [registration]
        if registration is not None
        else [],
    }


def gen_inplace_or_view_type(
    out: str,
    native_yaml_path: str,
    tags_yaml_path: str,
    fns_with_infos: list[NativeFunctionWithDifferentiabilityInfo],
    template_path: str,
) -> None:
    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    num_shards = 2

    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    fm.write_sharded(
        "ADInplaceOrViewType.cpp",
        [fn for fn in fns_with_infos if use_derived(fn)],
        key_fn=lambda fn: fn.func.root_name,
        base_env={
            "generated_comment": "@"
            + f"generated from {fm.template_dir_for_comments()}/ADInplaceOrViewType.cpp",
        },
        env_callable=gen_inplace_or_view_type_env,
        num_shards=num_shards,
        sharded_keys={
            "ops_headers",
            "inplace_or_view_method_definitions",
            "inplace_or_view_wrapper_registrations",
        },
    )

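# Illustrative invocation (hedged; the real arguments are supplied by
# tools/autograd/gen_autograd.py, and the paths below are examples only):
#   gen_inplace_or_view_type(
#       out="torch/csrc/autograd/generated",
#       native_yaml_path="aten/src/ATen/native/native_functions.yaml",
#       tags_yaml_path="aten/src/ATen/native/tags.yaml",
#       fns_with_infos=fns_with_infos,
#       template_path="tools/autograd/templates",
#   )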