# Generates ADInplaceOrViewType.h/cpp
#
# NOTE: If any changes are being made to the ADInplaceOrView codegen please also check
# if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp
# The fallback is expected to mimic this codegen, so we should keep the two in sync.

from __future__ import annotations

from torchgen.api import cpp
from torchgen.api.autograd import (
    dispatch_strategy,
    gen_differentiable_outputs,
    NativeFunctionWithDifferentiabilityInfo,
)
from torchgen.api.types import (
    BaseCType,
    Binding,
    boolT,
    ConstRefCType,
    CType,
    DispatcherSignature,
    intArrayRefT,
    longT,
    OptionalCType,
    symIntArrayRefT,
    SymIntT,
    tensorT,
)
from torchgen.code_template import CodeTemplate
from torchgen.context import with_native_function
from torchgen.model import (
    NativeFunction,
    SchemaKind,
    SelfArgument,
    TensorOptionsArguments,
    Type,
)
from torchgen.utils import FileManager

from .context import with_native_function_with_differentiability_info
from .gen_trace_type import (
    get_return_value,
    MANUAL_AUTOGRAD,
    tie_return_values,
    type_wrapper_name,
)


# See NOTE [ Autograd View Variables ] in variable.h for details.
# If you update list VIEW_FUNCTIONS or RETURNS_VIEWS_OF_INPUT,
# you **MUST** also update the public list of view ops accordingly in
# docs/source/tensor_view.rst. Note not all ATen functions are exposed to public,
# e.g. alias & sparse_coo_tensor_with_dims_and_tensors.
# View ops flagged as "with metadata change". emit_view_func substitutes a
# literal `true` for ${is_view_with_metadata_change} when the op's C++ name is
# in this list, so the replay view func is set up unconditionally for them.
VIEW_FUNCTIONS_WITH_METADATA_CHANGE = [
    "view_as_complex",
    "view_as_real",
    "_conj",
    "_neg_view",
    "_nested_get_values",
    "_nested_view_from_buffer",
    "_nested_view_from_jagged",
]

# A map: function name => name of the argument that all outputs are view of
VIEW_FUNCTIONS = {
    "numpy_T": "self",
    "alias": "self",
    "as_strided": "self",
    "diagonal": "self",
    "expand": "self",
    "permute": "self",
    "select": "self",
    "slice": "self",
    "slice_inverse": "self",
    "split": "self",
    "split_with_sizes": "self",
    "squeeze": "self",
    "t": "self",
    "transpose": "self",
    "unfold": "self",
    "unsqueeze": "self",
    "flatten": "self",
    "view": "self",
    "unbind": "self",
    "_indices": "self",
    "_values": "self",
    "indices": "self",
    "values": "self",
    "crow_indices": "self",
    "col_indices": "self",
    "ccol_indices": "self",
    "row_indices": "self",
    # sparse_coo ctor output should really be views of both indices and values,
    # but we only support making it a view of a single variable, and indices is
    # discrete anyways.
    # FIXME: clone indices on construction.
    "sparse_coo_tensor_with_dims_and_tensors": "values",
    "_reshape_alias": "self",
    "_test_autograd_multiple_dispatch_view": "self",
}

# Every metadata-changing view is a view of `self`. Using update() instead of
# a module-level `for` loop avoids leaking a loop variable into the module
# namespace; insertion order is the same as the list order.
VIEW_FUNCTIONS.update(dict.fromkeys(VIEW_FUNCTIONS_WITH_METADATA_CHANGE, "self"))

# note: some VIEW_FUNCTIONS are just compositions of the view functions above
# this list contains both the root view functions and any that are purely composed
# of viewing functions, and is used by the JIT to determine when an operator
# may return a view of its inputs; however they may sometimes return a copy.
# (e.g. `contiguous`)
RETURNS_VIEWS_OF_INPUT = set(VIEW_FUNCTIONS) | {
    "chunk",
    "detach",
    "contiguous",
    "reshape",
    "reshape_as",
    "expand_as",
    "view_as",
    "real",
    "imag",
    "narrow",
    "movedim",
    "tensor_split",
    "swapdims",
    "swapaxes",
    "mT",
    "mH",
    "adjoint",
    "matrix_H",
}

# These are the functions we consider views for the purposes of validating
# StorageImpl and TensorImpl in gen_variable_type.
# `_unsafe_view` is not included in VIEW_FUNCTIONS above because it is not a
# view for the purposes of ADInplaceOrView kernel, we do not want to call as_view
# See NOTE [Unsafe View] for more info.
ALL_VIEW_FUNCTIONS = {
    **VIEW_FUNCTIONS,
    "_unsafe_view": "self",
}

# C++ snippet: materialize an ArrayRef argument into an owning vector.
ARRAYREF_TO_VEC = CodeTemplate(
    """\
auto ${vec} = ${arg}.vec();
"""
)

# C++ snippet: unwrap an optional argument, falling back to ${default}.
OPTIONAL_TO_VAL = CodeTemplate(
    """\
auto ${val} = ${arg}.value_or(${default});
"""
)

CALL_DISPATCH = CodeTemplate(
    """\
at::_ops::${unambiguous_name}::call(${unpacked_args})"""
)

REVERSE_VIEW_DISPATCH = CodeTemplate(
    """\
${reverse_name}(${unpacked_args})"""
)

# C++ snippet: run ${body} once per element of a multi-output view result.
MULTI_OUTPUT_VIEW_ITERATION = CodeTemplate(
    """\
for (auto ${view_idx} : c10::irange(${var}.size())) {
  ${body}
}
"""
)

# C++ snippet: declare `func` / `rev_func` and populate them only when one of
# the listed runtime conditions (or the metadata-change flag) holds.
SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE = CodeTemplate(
    """\
std::unique_ptr<torch::autograd::ViewFunc> func(nullptr);
std::function<at::Tensor(const at::Tensor&)> rev_func=nullptr;
if (${is_view_with_metadata_change} ||
    !self.unsafeGetTensorImpl()->support_as_strided() ||
    self.unsafeGetTensorImpl()->is_python_dispatch() ||
    c10::AutogradState::get_tls_state().get_view_replay_enabled()) {
  ${replay_view_func}
  ${reverse_replay_view_func}
}
"""
)

REPLAY_VIEW_FUNC = CodeTemplate(
    """\
func = std::make_unique<${view_func_name}>(${view_func_args});
"""
)
REVERSE_REPLAY_VIEW_LAMBDA_FUNC = CodeTemplate(
    """\
rev_func = [=](const at::Tensor& ${input_view}) {
  return ${reverse_replay_view_call};
};
"""
)

METHOD_DEFINITION = CodeTemplate(
    """\
${return_type} ${type_wrapper_name}(${formals}) {
  ${type_definition_body}
}
"""
)

WRAPPER_REGISTRATION = CodeTemplate(
    """\
m.impl("${unqual_operator_name_with_overload}",
       TORCH_FN(${class_type}::${type_wrapper_name})
);
"""
)

AUTOGRAD_NOT_IMPLEMENTED_REGISTRATION = CodeTemplate(
    """\
m.impl("${unqual_operator_name_with_overload}", torch::autograd::autogradNotImplementedFallback());
"""
)

INPLACE_REDISPATCH = CodeTemplate(
    """\
{
  at::AutoDispatchBelowADInplaceOrView guard;
  at::_ops::${unambiguous_name}::redispatch(${unpacked_args});
}
"""
)

ASSIGN_RETURN_VALUE = CodeTemplate(
    """\
${return_values} = ${rhs_value};
"""
)

VIEW_REDISPATCH = CodeTemplate(
    """\
${assign_return_values} ([&]() {
  at::AutoDispatchBelowADInplaceOrView guard;
  return at::_ops::${unambiguous_name}::redispatch(${unpacked_args});
})();
"""
)

# Name of the C++ temporary that holds the redispatch result of a view op.
TMP_VAR = "_tmp"


# FIXME: Ideally these functions should be methods on Type class, but we have a
#        comment in codegen/model.py there saying these concepts are not well defined.
#        Thus we put a version that commonly used by autograd codegen here.
def is_tensor_type(t: Type) -> bool:
    """Return True when `t` is tensor-like and not list-like."""
    # TODO: Should handle optional here?
    return t.is_tensor_like() and t.is_list_like() is None


def is_tensor_list_type(t: Type) -> bool:
    """Return True when `t` is tensor-like and list-like."""
    # TODO: Should handle optional here?
    return t.is_tensor_like() and t.is_list_like() is not None


UNPACK_TENSOR = CodeTemplate(
    """\
auto${ref} ${arg_name}_ = unpack${suffix}(${arg_name}, "${arg_name}", ${arg_pos});"""
)


def unpacked_name(arg_name: str) -> str:
    """Return the name of the unpacked variable for `arg_name` (trailing underscore)."""
    return arg_name + "_"


# e.g. select.int -> select_copy_int_inverse()
def inverse_view_name(f: NativeFunction) -> str:
    """Build `<root_name>_copy[_<overload>]_inverse` for the given function."""
    overload = str(f.func.name.overload_name)
    overload_suffix = f"_{overload}" if overload else ""
    return f"{f.root_name}_copy{overload_suffix}_inverse"


def extract_bindings(f: NativeFunction) -> list[Binding]:
    """Return the flattened C++ bindings for every schema-order argument of `f`."""
    return [
        r
        for a in f.func.schema_order_arguments()
        for r in cpp.argument(
            a,
            method=False,
            symint=True,
            cpp_no_default_args=set(),
            faithful=False,
            has_tensor_options=False,
        )
    ]


@with_native_function
def unpack_args(f: NativeFunction) -> tuple[list[str], list[Binding]]:
    """Emit `unpack(...)` statements for the non-nullable tensor-like arguments of `f`.

    Returns:
        A pair `(body, unpacked_bindings)`. `body` holds one UNPACK_TENSOR
        statement per unpacked argument. `unpacked_bindings` holds one binding
        per input binding, in order: unpacked arguments are renamed via
        `unpacked_name`, all others are passed through unchanged.

    Raises:
        RuntimeError: if a binding is a TensorOptionsArguments.
    """
    body: list[str] = []
    unpacked_bindings: list[Binding] = []

    for i, binding in enumerate(extract_bindings(f)):
        assert not isinstance(binding.argument, SelfArgument)
        if isinstance(binding.argument, TensorOptionsArguments):
            raise RuntimeError("VariableKernel shouldn't take TensorOptions")

        arg_type = binding.argument.type
        # Non-tensor and nullable arguments are forwarded as-is.
        if not arg_type.is_tensor_like() or arg_type.is_nullable():
            unpacked_bindings.append(binding)
            continue

        # Only non-nullable arguments reach this point, so the "_opt" unpack
        # suffix is never needed; tensor lists are bound by value, single
        # tensors by reference.
        ref = "" if is_tensor_list_type(arg_type) else "&"
        body.append(
            UNPACK_TENSOR.substitute(
                arg_name=binding.name,
                arg_pos=i,
                suffix="",
                ref=ref,
            )
        )
        unpacked_bindings.append(
            Binding(
                name=unpacked_name(binding.name),
                nctype=binding.nctype,
                argument=binding.argument,
                default=binding.default,
            )
        )

    return body, unpacked_bindings


def get_base_name(f: NativeFunction) -> str:
    """Return the base operator name of `f` (`f.func.name.name.base`)."""
    return f.func.name.name.base  # TODO: should be str(f.func.name.name)?
def get_view_info(f: NativeFunction) -> str | None:
    """Return the name of the argument that the outputs of `f` are views of.

    Looks the base name up in VIEW_FUNCTIONS; names that only appear in
    RETURNS_VIEWS_OF_INPUT map to "self". Returns None for non-view ops.
    """
    base_name = get_base_name(f)
    view_info = VIEW_FUNCTIONS.get(base_name, None)
    if view_info is None and base_name in RETURNS_VIEWS_OF_INPUT:
        view_info = "self"
    return view_info


def emit_view_func(
    f: NativeFunction, bindings: list[Binding], view_idx: str | None = None
) -> str:
    """Generate an additional lambda function to recover views in backward when as_strided is not supported.
    See Note [View + Inplace update for base tensor] and [View + Inplace update for view tensor] for more details.

    Args:
        f: the view function to generate code for.
        bindings: C++ bindings of `f`'s arguments.
        view_idx: name of the C++ index variable for multi-output views, or
            None for single-output views.

    Returns:
        The C++ snippet produced by
        SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE.

    Raises:
        TypeError: if a non-`self` argument has a C++ type outside the known list.
    """
    # TODO: Clean this logic up if we get rid of reverse view funcs or reify them.
    input_base = "input_base"
    replay_view_func = ""
    updated_args: list[str] = []
    known_view_arg_simple_types: list[CType] = [
        BaseCType(longT),
        OptionalCType(BaseCType(longT)),
        BaseCType(SymIntT),
        OptionalCType(BaseCType(SymIntT)),
        BaseCType(boolT),
        BaseCType(intArrayRefT),
        BaseCType(symIntArrayRefT),
        ConstRefCType(BaseCType(tensorT)),
        ConstRefCType(OptionalCType(BaseCType(tensorT))),
    ]
    for binding in bindings:
        arg, arg_type = binding.name, binding.nctype.type
        if arg == "self":
            updated_args.append(input_base)
            continue
        if arg_type not in known_view_arg_simple_types:
            known_types_str = ", ".join([str(t) for t in known_view_arg_simple_types])
            raise TypeError(
                f"You are adding an {arg_type} {arg} argument to op {cpp.name(f.func)} in addition to known types: "
                f"{known_types_str}. Please update the list or materialize it so that it can be closed "
                "over by value, also add a test in pytorch/xla/test/test_operations.py where this code "
                "is exercised."
            )
        if arg_type == BaseCType(intArrayRefT) or arg_type == BaseCType(
            symIntArrayRefT
        ):
            # It's not safe to close over IntArrayRef by value, since this is a
            # reference type, so materialize a vector to close over by value
            arg_vec = arg + "_vec"
            replay_view_func += ARRAYREF_TO_VEC.substitute(arg=arg, vec=arg_vec)
            updated_args.append(arg_vec)
        elif arg_type == OptionalCType(BaseCType(longT)):
            # Materialize int64_t? to int64_t
            arg_value = arg + "_val"
            replay_view_func += OPTIONAL_TO_VAL.substitute(
                arg=arg, val=arg_value, default="0"
            )
            updated_args.append(arg_value)
        else:
            # Everything else is passed through under its own name.
            # NB: This includes tensor arguments, i.e. closing over a tensor. If a user
            # modifies this tensor, this will be silently incorrect. The proper thing
            # to do is to store the version counter and copy on write.
            updated_args.append(arg)

    # Imported lazily inside the function, as in the original code.
    from .gen_view_funcs import view_func_name

    view_func_args = [b.name for b in bindings if b.name != "self"]
    if view_idx is not None:
        view_func_args.append(f"{view_idx}")
    replay_view_func += REPLAY_VIEW_FUNC.substitute(
        view_func_name=view_func_name(f, include_namespace=True),
        view_func_args=view_func_args,
    )

    input_view = "input_view"
    reverse_unpacked_args = [
        "self",
        f"{input_view}",
        # inverse_return_mode=
        "at::functionalization::InverseReturnMode::AlwaysView",
        *(() if view_idx is None else (f"{view_idx}",)),
        # skip input_base arg
        *updated_args[1:],
    ]

    from torchgen.api.functionalization import reverse_name

    reverse_replay_view_call = REVERSE_VIEW_DISPATCH.substitute(
        reverse_name=reverse_name(f, include_namespace=True),
        unpacked_args=reverse_unpacked_args,
    )
    reverse_replay_view_func = REVERSE_REPLAY_VIEW_LAMBDA_FUNC.substitute(
        input_view=input_view, reverse_replay_view_call=reverse_replay_view_call
    )

    is_view_with_metadata_change = (
        "true" if cpp.name(f.func) in VIEW_FUNCTIONS_WITH_METADATA_CHANGE else "false"
    )

    return SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE.substitute(
        is_view_with_metadata_change=is_view_with_metadata_change,
        replay_view_func=replay_view_func,
        reverse_replay_view_func=reverse_replay_view_func,
    )


def emit_view_body(
    fn: NativeFunctionWithDifferentiabilityInfo, var: str
) -> tuple[str, str]:
    """Build the C++ code that turns the redispatch result `var` into a view.

    Returns:
        A pair `(call, rhs_value)`: `call` is C++ code to emit before the
        return values are assigned (possibly empty), and `rhs_value` is the
        C++ expression to assign to them.

    Raises:
        TypeError: if `fn` has no view info.
        RuntimeError: if the single differentiable output is neither Tensor nor
            Tensor[], or if there is more than one differentiable output.
    """
    # See NOTE [ Autograd View Variables ] in variable.h for details.
    f = fn.func
    base_name = get_base_name(f)
    view_info = get_view_info(f)
    call = ""
    differentiable_outputs = gen_differentiable_outputs(fn)
    differentiable_output_vars = {r.name for r in differentiable_outputs}
    if not isinstance(view_info, str):
        raise TypeError(
            f"The view info should be a string for {base_name}, but it is: {view_info}"
        )
    if len(differentiable_output_vars) == 0:
        # no output is differentiable (.indices() for SparseTensors for example)
        rhs_value = (
            f"as_view({view_info}, {var}, "
            f"/* is_bw_differentiable */ false, /* is_fw_differentiable */ false)"
        )
    elif len(differentiable_output_vars) == 1:
        # Single differentiable output (Tensor or Tensor[])
        return_info = differentiable_outputs[0]
        # We only support simple Tensor or a TensorList for functions that return views
        if not is_tensor_type(return_info.type) and not is_tensor_list_type(
            return_info.type
        ):
            raise RuntimeError(
                f"{base_name} that return differentiable views can only return Tensor or Tensor[]"
            )

        # See Note [ View + Inplace detection]
        def get_creation_meta_in_mode(original: str) -> str:
            creation_meta_with_grad_mode = f"(at::GradMode::is_enabled() ? {original} : CreationMeta::NO_GRAD_MODE)"
            return f"InferenceMode::is_enabled() ? CreationMeta::INFERENCE_MODE : {creation_meta_with_grad_mode}"

        # Only allow rebasing of the history if we return a single Tensor
        # If we are in a no grad block, raise a warning
        # See NOTE [ View + Inplace detection ] for more details about this logic
        if is_tensor_list_type(return_info.type):
            creation_meta = get_creation_meta_in_mode("CreationMeta::MULTI_OUTPUT_NODE")
            view_idx = "view_idx"
            view_func = emit_view_func(
                f, extract_bindings(f), view_idx=view_idx
            ).strip()
            as_view_call = (
                f"as_view(/* base */ {view_info}, /* output */ {var}[{view_idx}], "
                "/* is_bw_differentiable */ true, /* is_fw_differentiable */ true, "
                "/* view_func */ std::move(func), /* rev_view_func */ rev_func, "
                f"/* creation_meta */ {creation_meta});"
            )
            call += MULTI_OUTPUT_VIEW_ITERATION.substitute(
                var=var, view_idx=view_idx, body=f"{view_func}\n{as_view_call}"
            )
            rhs_value = f"std::move({var})"
        else:
            call += emit_view_func(f, extract_bindings(f), view_idx=None)
            creation_meta = get_creation_meta_in_mode("CreationMeta::DEFAULT")
            rhs_value = (
                f"as_view(/* base */ {view_info}, /* output */ {var}, /* is_bw_differentiable */ true, "
                "/* is_fw_differentiable */ true, "
                f"/* view_func */ std::move(func), /* rev_view_func */ rev_func, /* creation_meta */ {creation_meta})"
            )
    else:
        # This could be supported but we don't need it at the moment, so keeping things simple.
        raise RuntimeError(
            "Function that return multiple differentiable output "
            "when at least one of them is view is not supported."
        )
    return call, rhs_value


def modifies_arguments(f: NativeFunction) -> bool:
    """Return True when the schema of `f` is of the inplace or out kind."""
    return f.func.kind() in (SchemaKind.inplace, SchemaKind.out)


@with_native_function_with_differentiability_info
def emit_inplace_or_view_body(fn: NativeFunctionWithDifferentiabilityInfo) -> list[str]:
    """Return the C++ statements forming the body of the ADInplaceOrView kernel for `fn`.

    Ops that modify their arguments get a redispatch followed by one
    `increment_version` per return name; all other ops must be views and get a
    redispatch into a temporary followed by the `as_view` wiring.
    """
    f = fn.func
    inplace_view_body: list[str] = []

    dispatcher_sig = DispatcherSignature.from_schema(f.func)
    dispatcher_exprs = dispatcher_sig.exprs()

    # code-generated ADInplaceOrView kernels plumb and recompute dispatch keys directly through the kernel for performance.
    # See Note [Plumbing Keys Through The Dispatcher] for details.
    dispatch_key_set = "ks & c10::after_ADInplaceOrView_keyset"
    redispatch_args = ", ".join([dispatch_key_set] + [a.expr for a in dispatcher_exprs])

    # Note that this calls the slow, dispatching variants of manual_cpp_binding ops.
    # We could probably work harder to ensure that the fast variants are called instead, but the perf benefit would be minimal.
    if modifies_arguments(f):  # inplace op
        inplace_view_body.append(
            INPLACE_REDISPATCH.substitute(
                unambiguous_name=f.func.name.unambiguous_name(),
                unpacked_args=redispatch_args,
            )
        )
        inplace_view_body.extend(
            f"increment_version({r});" for r in cpp.return_names(f)
        )
    else:
        assert get_view_info(f) is not None
        inplace_view_body.append(
            VIEW_REDISPATCH.substitute(
                assign_return_values="auto " + TMP_VAR + " = ",
                unambiguous_name=f.func.name.unambiguous_name(),
                unpacked_args=redispatch_args,
            )
        )
        call, rhs_value = emit_view_body(fn, TMP_VAR)
        inplace_view_body.append(call)
        assert rhs_value is not None
        inplace_view_body.append(
            ASSIGN_RETURN_VALUE.substitute(
                return_values=tie_return_values(f), rhs_value=rhs_value
            )
        )
    if f.func.returns:
        inplace_view_body.append(f"return {get_return_value(f)};")
    return inplace_view_body


@with_native_function
def gen_formals(f: NativeFunction) -> str:
    """Return the comma-separated C++ formal parameter list of the kernel for `f`."""
    return ", ".join(
        # code-generated autograd kernels plumb and recompute dispatch keys directly through the kernel for performance.
        # See Note [Plumbing Keys Through The Dispatcher] for details.
        ["c10::DispatchKeySet ks"]
        + [
            f'{cpp.argument_type(a, binds="__placeholder__", symint=True).cpp_type()} {a.name}'
            for a in f.func.schema_order_arguments()
        ]
    )


def _needs_inplace_or_view_kernel(f: NativeFunction) -> bool:
    """Return True when an ADInplaceOrView kernel must be generated for `f`.

    Shared by the definition and registration generators so the two always
    agree on which ops get a kernel.
    """
    if get_view_info(f) is not None:
        return True
    # For functions that modify their inputs but don't return them,
    # we can't give them autograd support.
    # See https://github.com/pytorch/pytorch/issues/53796
    return modifies_arguments(f) and len(f.func.returns) != 0


@with_native_function_with_differentiability_info
def inplace_or_view_method_definition(
    fn: NativeFunctionWithDifferentiabilityInfo,
) -> str | None:
    """Return the C++ kernel definition for `fn`, or None if no kernel is generated."""
    f = fn.func
    if not _needs_inplace_or_view_kernel(f):
        return None
    return METHOD_DEFINITION.substitute(
        return_type=cpp.returns_type(f.func.returns, symint=True).cpp_type(),
        type_wrapper_name=type_wrapper_name(f),
        formals=gen_formals(f),
        type_definition_body=emit_inplace_or_view_body(fn),
    )


@with_native_function_with_differentiability_info
def inplace_or_view_method_registration(
    fn: NativeFunctionWithDifferentiabilityInfo,
) -> str | None:
    """Return the C++ `m.impl(...)` registration for `fn`, or None if no kernel is generated."""
    f = fn.func
    if not _needs_inplace_or_view_kernel(f):
        return None
    return WRAPPER_REGISTRATION.substitute(
        unqual_operator_name_with_overload=f.func.name,
        type_wrapper_name=type_wrapper_name(f),
        class_type="ADInplaceOrView",
    )


def use_derived(fn: NativeFunctionWithDifferentiabilityInfo) -> bool:
    """Return True when `fn` is not in MANUAL_AUTOGRAD and its dispatch strategy is "use_derived"."""
    f = fn.func
    name = cpp.name(f.func)
    return name not in MANUAL_AUTOGRAD and dispatch_strategy(fn) == "use_derived"


def gen_inplace_or_view_type_env(
    fn: NativeFunctionWithDifferentiabilityInfo,
) -> dict[str, list[str]]:
    """Build the per-function template environment for ADInplaceOrViewType.cpp.

    Each value is a list with zero or one entry, depending on whether a kernel
    definition / registration is generated for `fn`.
    """
    definition = inplace_or_view_method_definition(fn)
    registration = inplace_or_view_method_registration(fn)

    return {
        "ops_headers": (
            [f"#include <ATen/ops/{fn.func.root_name}_ops.h>"]
            if definition is not None
            else []
        ),
        "inplace_or_view_method_definitions": (
            [definition] if definition is not None else []
        ),
        "inplace_or_view_wrapper_registrations": (
            [registration] if registration is not None else []
        ),
    }


def gen_inplace_or_view_type(
    out: str,
    native_yaml_path: str,
    tags_yaml_path: str,
    fns_with_infos: list[NativeFunctionWithDifferentiabilityInfo],
    template_path: str,
) -> None:
    """Write the sharded ADInplaceOrViewType.cpp files into `out`.

    Only functions for which `use_derived` is true are included.
    `native_yaml_path` and `tags_yaml_path` are accepted but not read here.
    """
    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    num_shards = 2

    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    fm.write_sharded(
        "ADInplaceOrViewType.cpp",
        [fn for fn in fns_with_infos if use_derived(fn)],
        key_fn=lambda fn: fn.func.root_name,
        base_env={
            # NOTE(review): the "@" is concatenated separately, presumably so this
            # source file does not itself contain the generated-file marker — confirm.
            "generated_comment": "@"
            + f"generated from {fm.template_dir_for_comments()}/ADInplaceOrViewType.cpp",
        },
        env_callable=gen_inplace_or_view_type_env,
        num_shards=num_shards,
        sharded_keys={
            "ops_headers",
            "inplace_or_view_method_definitions",
            "inplace_or_view_wrapper_registrations",
        },
    )