from torchgen.api.lazy import LazyArgument, LazyIrSchema
from torchgen.api.types import OptionalCType


def ts_lowering_body(schema: LazyIrSchema) -> str:
    """Generate the C++ body of a TorchScript lowering for one lazy IR node.

    The emitted statements do the following:

    * Fill ``arguments`` (positional) and ``kwarguments`` (keyword) vectors of
      ``torch::jit::NamedValue``.
    * Call ``torch::lazy::LowerTSBuiltin(function, op().op, ...)``.
    * Check that the number of outputs equals ``len(schema.returns)``.
    * Return those outputs.

    Lazy-value arguments are read from the node's operands using a running
    C++ counter ``i``. Every other argument is passed by name, referencing the
    C++ variable of the same name. Positional arguments are emitted first,
    then keyword values, then keyword scalars.

    Args:
        schema: Schema of the operator being lowered. Only the attributes
            ``positional_args``, ``keyword_values``, ``keyword_scalars``,
            ``returns`` and ``aten_name`` are read.

    Returns:
        The generated C++ statements as a single string. Each statement is
        indented by four spaces, and the string ends with a newline.
    """

    def get_value(arg: LazyArgument) -> str:
        # Each lazy value takes the next operand and advances the C++ counter
        # `i` declared in the generated body. For an optional value, the C++
        # ternary evaluates `operand(i++)` only when `has_<name>` is true and
        # passes nullptr otherwise.
        # NOTE(review): this presumes the node stores an operand for an
        # optional value only when it is present — confirm against the IR
        # node constructor.
        if isinstance(arg.lazy_type, OptionalCType):
            return f"has_{arg.name} ? loctx->GetOutputOp(operand(i++)) : nullptr"
        return "loctx->GetOutputOp(operand(i++))"

    # Positional lazy values are emplaced unnamed (straight from the operand).
    # Other positional args are emplaced as ("name", name).
    emplace_arguments = [
        get_value(arg) if arg.is_lazy_value else f'"{arg.name}", {arg.name}'
        for arg in schema.positional_args
    ]
    emplace_kwarg_values = [
        f'"{arg.name}", {get_value(arg)}' for arg in schema.keyword_values
    ]
    emplace_kwarg_scalars = [
        f'"{arg.name}", {arg.name}' for arg in schema.keyword_scalars
    ]
    # Keyword values precede keyword scalars. Built once and reused for both
    # the emplace statements and the reserve() size.
    emplace_kwargs = emplace_kwarg_values + emplace_kwarg_scalars

    # The joiner re-indents continuation lines to match the 4-space
    # indentation of the template below.
    emplace_arguments_str = "\n    ".join(
        f"arguments.emplace_back({a});" for a in emplace_arguments
    )
    emplace_kwarguments = "\n    ".join(
        f"kwarguments.emplace_back({a});" for a in emplace_kwargs
    )
    return f"""\
    std::vector<torch::jit::NamedValue> arguments;
    std::vector<torch::jit::NamedValue> kwarguments;
    arguments.reserve({len(emplace_arguments)});
    kwarguments.reserve({len(emplace_kwargs)});
    size_t i = 0;
    {emplace_arguments_str}
    {emplace_kwarguments}
    torch::lazy::TSOpVector {schema.aten_name}_out = torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments);
    TORCH_CHECK_EQ({schema.aten_name}_out.size(), {len(schema.returns)});

    return {schema.aten_name}_out;
"""