xref: /aosp_15_r20/external/pytorch/torchgen/dest/lazy_ts_lowering.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from torchgen.api.lazy import LazyArgument, LazyIrSchema
2from torchgen.api.types import OptionalCType
3
4
def ts_lowering_body(schema: LazyIrSchema) -> str:
    """Render the C++ method body that lowers a lazy IR node via a TS builtin.

    The generated body fills ``arguments``/``kwarguments`` vectors of
    ``torch::jit::NamedValue`` from the schema's positional and keyword
    arguments, then dispatches through ``torch::lazy::LowerTSBuiltin``.
    Only the functional (non-out/non-inplace) variant is handled here.
    """

    def render_value(arg: LazyArgument) -> str:
        # Optional lazy values carry a has_<name> flag; when absent they
        # lower to nullptr instead of consuming an operand.
        if isinstance(arg.lazy_type, OptionalCType):
            return f"has_{arg.name} ? loctx->GetOutputOp(operand(i++)) : nullptr"
        return "loctx->GetOutputOp(operand(i++))"

    # Positional args: lazy values become operand lookups, everything else
    # is passed through as a named scalar pair.
    positional = [
        render_value(arg) if arg.is_lazy_value else f'"{arg.name}", {arg.name}'
        for arg in schema.positional_args
    ]
    # Keyword args: lazy values first, then plain scalars, matching the
    # original emission order.
    keywords = [
        f'"{arg.name}", {render_value(arg)}' for arg in schema.keyword_values
    ] + [f'"{arg.name}", {arg.name}' for arg in schema.keyword_scalars]

    arg_lines = "\n    ".join(
        f"arguments.emplace_back({entry});" for entry in positional
    )
    kwarg_lines = "\n    ".join(
        f"kwarguments.emplace_back({entry});" for entry in keywords
    )
    return f"""\
    std::vector<torch::jit::NamedValue> arguments;
    std::vector<torch::jit::NamedValue> kwarguments;
    arguments.reserve({len(positional)});
    kwarguments.reserve({len(keywords)});
    size_t i = 0;
    {arg_lines}
    {kwarg_lines}
    torch::lazy::TSOpVector {schema.aten_name}_out = torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments);
    TORCH_CHECK_EQ({schema.aten_name}_out.size(), {len(schema.returns)});

    return {schema.aten_name}_out;
"""
49