# Generates ViewFuncs.h/cpp
#
# NOTE: If any changes are being made to the ViewFunc codegen please also check
# if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp
# The fallback is expected to mimic this codegen, so we should keep the two in sync.

from __future__ import annotations

from typing import TYPE_CHECKING

import torchgen.api.dispatcher as dispatcher
from torchgen.api.translate import translate
from torchgen.api.types import (
    BaseCType,
    Binding,
    NamedCType,
    SymIntT,
    tensorT,
    VectorCType,
)
from torchgen.code_template import CodeTemplate
from torchgen.model import Argument, NativeFunction, OptionalType
from torchgen.utils import FileManager

from .gen_inplace_or_view_type import (
    CALL_DISPATCH,
    extract_bindings,
    get_view_info,
    modifies_arguments,
    use_derived,
)


if TYPE_CHECKING:
    from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo


# C++ declaration emitted into ViewFuncs.h for each generated view func struct.
#
# Layout matters here:
#   * `#define ..._AVAILABLE` must stay alone on its line; a preprocessor
#     directive extends to the end of the line it starts on.
#   * NOTE(review): multi-entry placeholders such as ${state} are kept alone on
#     their own (indented) line; CodeTemplate is expected to expand list values
#     one item per line in that position -- confirm against
#     torchgen.code_template if this layout is ever changed.
FUNCTION_DECLARATION = CodeTemplate(
    """\
#define ${uppercase_op}_AVAILABLE
struct ${op} : public ${superclass} {
  ${op}(${constructor_args}) ${initializer_list}
  {};
  virtual ~${op}() override {};
  virtual std::vector<c10::SymInt> get_symints() const override;
  virtual size_t num_symints() const override;
  virtual std::vector<at::Tensor> get_tensors() const override;
  virtual size_t num_tensors() const override;
  virtual at::Tensor operator()(const at::Tensor&) const override;
  virtual std::unique_ptr<ViewFunc> clone_and_set(
      std::optional<std::vector<c10::SymInt>> = ::std::nullopt,
      std::optional<std::vector<at::Tensor>> = ::std::nullopt) const override;

protected:
  virtual void set_symints(std::vector<c10::SymInt>) override;
  virtual void set_tensors(std::vector<at::Tensor>) override;

private:
  ${state}
};

"""
)

# C++ definitions emitted into ViewFuncs.cpp for each generated view func
# struct. The statement-list placeholders (${get_symints}, ${set_symints},
# ${get_tensors}, ${set_tensors}) each sit alone on their own line for the same
# reason as ${state} in FUNCTION_DECLARATION.
FUNCTION_DEFINITION = CodeTemplate(
    """\
std::vector<c10::SymInt> ${op}::get_symints() const {
  ${get_symints}
}

size_t ${op}::num_symints() const {
  return static_cast<size_t>(${num_symints});
}

void ${op}::set_symints(std::vector<c10::SymInt> ${symints_vec}) {
  TORCH_INTERNAL_ASSERT(${symints_vec}.size() == num_symints());
  ${set_symints}
}

std::vector<at::Tensor> ${op}::get_tensors() const {
  ${get_tensors}
}

size_t ${op}::num_tensors() const {
  return static_cast<size_t>(${num_tensors});
}

void ${op}::set_tensors(std::vector<at::Tensor> ${tensors_vec}) {
  TORCH_INTERNAL_ASSERT(${tensors_vec}.size() == num_tensors());
  ${set_tensors}
}

at::Tensor ${op}::operator()(const at::Tensor& ${call_input_name}) const {
  return ${op_call};
}

std::unique_ptr<ViewFunc> ${op}::clone_and_set(
    std::optional<std::vector<c10::SymInt>> ${symints_vec},
    std::optional<std::vector<at::Tensor>> ${tensors_vec}) const {
  auto output = std::make_unique<${op}>(${clone_args});
  if (${symints_vec}.has_value()) {
    output->set_symints(std::move(*(${symints_vec})));
  }
  if (${tensors_vec}.has_value()) {
    output->set_tensors(std::move(*(${tensors_vec})));
  }
  return output;
}

"""
)


def view_func_name(
    f: NativeFunction, include_namespace: bool = False, camel_case: bool = True
) -> str:
    """Return the name of the generated view func struct for ``f``.

    e.g. as_strided -> AsStridedViewFunc for camel case or
    as_strided_view_func otherwise.

    Args:
        f: Native function whose unambiguous name seeds the struct name.
        include_namespace: If True, prefix the result with
            ``torch::autograd::generated::``.
        camel_case: If True, title-case each underscore-separated piece and
            join them; a leading underscore on the op name is preserved.
    """
    name = f.func.name.unambiguous_name()
    # Overload names are joined with '.', which is not valid in an identifier.
    result = f"{name.replace('.', '_')}_view_func"
    if camel_case:
        is_private = result.startswith("_")
        result = "".join(piece.title() for piece in result.split("_"))
        if is_private:
            # put the leading underscore back in
            result = f"_{result}"
    namespace = "torch::autograd::generated::" if include_namespace else ""
    return f"{namespace}{result}"


def is_symint_or_tensor(arg: Argument) -> bool:
    """Return True if ``arg``'s type is tensor-like or symint-like."""
    return arg.type.is_tensor_like() or arg.type.is_symint_like()


def remove_const_ref(binding: Binding) -> Binding:
    """Return a copy of ``binding`` whose C++ type has const-ref removed.

    Name, argument and default are carried over unchanged; only ``nctype`` is
    replaced by ``nctype.remove_const_ref()``.
    """
    return Binding(
        name=binding.name,
        nctype=binding.nctype.remove_const_ref(),
        argument=binding.argument,
        default=binding.default,
    )


def returns_multi_tensor(fn: NativeFunction) -> bool:
    """Return True if ``fn``'s single return is a list-like of tensors.

    Asserts that ``fn`` has exactly one return.
    """
    returns = fn.func.returns
    assert len(returns) == 1
    returns_list_like = returns[0].type.is_list_like() is not None
    returns_tensor_like = returns[0].type.is_tensor_like()
    return returns_list_like and returns_tensor_like


def generate_state_getter_setter(
    bindings: list[Binding],
    state_vec_type: NamedCType,
) -> tuple[list[str], list[str], str]:
    """Generate C++ logic for getting / setting state of a particular type.

    Args:
        bindings: List of state bindings of interest (may be empty).
        state_vec_type: Type of vector to either return or copy from.

    Returns:
        Tuple of (list of getter logic strings, list of setter logic strings,
        string with the num-items expression).
    """
    getter_logic: list[str] = []
    setter_logic: list[str] = []

    state_vec = state_vec_type.name
    getter_logic.append(f"{state_vec_type.cpp_type()} {state_vec};")
    if bindings:
        # Running offset into the state vector used by the emitted setters.
        setter_logic.append("auto i = 0;")

    num_exprs: list[str] = []
    last_index = len(bindings) - 1
    for i, b in enumerate(bindings):
        assert isinstance(b.argument, Argument)
        if b.argument.type.is_list_like():
            # Handle list-likes.
            num_expr = f"{b.name}.size()"
            getter = f"{state_vec}.insert({state_vec}.end(), {b.name}.begin(), {b.name}.end());"
            setter = f"std::copy({state_vec}.begin() + i, {state_vec}.begin() + i + {b.name}.size(), {b.name}.begin());"
        elif isinstance(b.argument.type, OptionalType):
            # Handle optionals.
            num_expr = f"({b.name}.has_value() ? 1 : 0)"
            conditional = f"if({b.name}.has_value())"
            getter = (
                f"{conditional} {state_vec}.insert({state_vec}.end(), *({b.name}));"
            )
            setter = f"{conditional} {b.name} = {state_vec}[i];"
        else:
            num_expr = "1"
            getter = f"{state_vec}.push_back({b.name});"
            setter = f"{b.name} = {state_vec}[i];"

        num_exprs.append(num_expr)
        getter_logic.append(getter)
        setter_logic.append(setter)
        if i < last_index:
            # Advance the offset past this binding's items; not needed after
            # the final binding.
            setter_logic.append(f"i += {num_expr};")

    # Reserve / assert based on the total number of items expression.
    num_items = " + ".join(num_exprs) if num_exprs else "0"
    if bindings:
        # Goes right after the vector declaration at index 0.
        getter_logic.insert(1, f"{state_vec}.reserve({num_items});")

    getter_logic.append(f"return {state_vec};")

    return getter_logic, setter_logic, num_items


def process_function(fn: NativeFunction, template: CodeTemplate) -> str:
    """Render ``template`` for the view func struct generated from ``fn``.

    Args:
        fn: Native function to generate a view func struct for.
        template: Template to substitute into (FUNCTION_DECLARATION or
            FUNCTION_DEFINITION in this file).

    Returns:
        The substituted template text.
    """
    bindings = extract_bindings(fn)
    non_self_bindings = [b for b in bindings if b.name != "self"]

    # NOTE(review): the [1:] slice presumes `self` is the first flat argument.
    non_self_args = fn.func.arguments.flat_all[1:]
    non_self_value_bindings = [
        dispatcher.argument(a, remove_non_owning_ref_types=True) for a in non_self_args
    ]

    # Generate constructor / clone args for the generated struct.
    constructor_args = [b.defn() for b in non_self_bindings]
    clone_args = [b.name for b in non_self_bindings]

    # Generate state variable declarations for the generated struct.
    state_variables = [
        f"{remove_const_ref(b).defn()};" for b in non_self_value_bindings
    ]

    # Generate initializer list expressions for the generated struct.
    # allow_expensive_conversions=True because we need to store e.g. SymIntArrayRefs as
    # vectors.
    init_exprs = translate(
        non_self_bindings, non_self_value_bindings, allow_expensive_conversions=True
    )
    initializers = []
    for b, init_expr in zip(non_self_bindings, init_exprs):
        name = b.nctype.name
        assert isinstance(name, str)
        initializers.append(f"{name}({init_expr.expr})")

    # Generate call to underlying view op
    call_input_name = "input_base"
    op_call_args = [call_input_name, *(b.name for b in non_self_bindings)]
    op_call = CALL_DISPATCH.substitute(
        unambiguous_name=fn.func.name.unambiguous_name(),
        unpacked_args=op_call_args,
    )

    # Multi-output views additionally require a view_idx for disambiguation.
    if returns_multi_tensor(fn):
        view_idx_name = "view_idx"
        view_idx_typename = "int64_t"
        view_idx_decl = f"{view_idx_typename} {view_idx_name}"
        constructor_args.append(view_idx_decl)
        clone_args.append(view_idx_name)
        state_variables.append(f"{view_idx_decl};")
        initializers.append(f"{view_idx_name}({view_idx_name})")
        op_call += f"[{view_idx_name}]"

    # Generate initializer list for the generated struct.
    initializer_list = f": {', '.join(initializers)}" if len(initializers) > 0 else ""

    # Generate getter / setter logic for any symints.
    symint_bindings = [
        b
        for b in non_self_bindings
        if isinstance(b.argument, Argument) and b.argument.type.is_symint_like()
    ]
    symints_vec_type = NamedCType("symints", VectorCType(BaseCType(SymIntT)))
    get_symints, set_symints, num_symints = generate_state_getter_setter(
        symint_bindings, symints_vec_type
    )

    # Generate getter / setter logic for any tensors.
    tensor_bindings = [
        b
        for b in non_self_bindings
        if isinstance(b.argument, Argument) and b.argument.type.is_tensor_like()
    ]
    tensors_vec_type = NamedCType("tensors", VectorCType(BaseCType(tensorT)))
    get_tensors, set_tensors, num_tensors = generate_state_getter_setter(
        tensor_bindings, tensors_vec_type
    )

    return template.substitute(
        op=view_func_name(fn),
        uppercase_op=view_func_name(fn, camel_case=False).upper(),
        superclass="torch::autograd::ViewFunc",
        initializer_list=initializer_list,
        state=state_variables,
        constructor_args=constructor_args,
        clone_args=clone_args,
        symints_vec=symints_vec_type.name,
        get_symints=get_symints,
        set_symints=set_symints,
        num_symints=num_symints,
        tensors_vec=tensors_vec_type.name,
        get_tensors=get_tensors,
        set_tensors=set_tensors,
        num_tensors=num_tensors,
        call_input_name=call_input_name,
        op_call=op_call,
    )


def gen_view_funcs(
    out: str,
    fns_with_infos: list[NativeFunctionWithDifferentiabilityInfo],
    template_path: str,
) -> None:
    """Write ViewFuncs.h and ViewFuncs.cpp into ``out``.

    Args:
        out: Install directory handed to the FileManager.
        fns_with_infos: Candidate functions; only those accepted by
            ``use_derived`` that have view info and do not modify their
            arguments get a generated view func.
        template_path: Directory containing the ViewFuncs.h / ViewFuncs.cpp
            templates.
    """
    # don't need the info parts, just the function
    fns = [fn.func for fn in fns_with_infos if use_derived(fn)]
    # only want out-of-place views
    view_fns = [
        fn for fn in fns if get_view_info(fn) is not None and not modifies_arguments(fn)
    ]

    declarations = [process_function(fn, FUNCTION_DECLARATION) for fn in view_fns]
    definitions = [process_function(fn, FUNCTION_DEFINITION) for fn in view_fns]
    # One per-operator ATen header for each op the generated code dispatches to.
    ops_headers = [f"#include <ATen/ops/{fn.root_name}_ops.h>" for fn in view_fns]

    file_basename = "ViewFuncs"
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    for suffix in [".h", ".cpp"]:
        fname = file_basename + suffix
        fm.write_with_template(
            fname,
            fname,
            lambda: {
                # "@" is concatenated separately so this source file does not
                # itself contain the literal generated-file marker.
                "generated_comment": "@"
                + f"generated from {fm.template_dir_for_comments()}/"
                + fname,
                "view_func_declarations": declarations,
                "view_func_definitions": definitions,
                "ops_headers": ops_headers,
            },
        )