# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

9r"""
10Handle the following op convertions:
11- convert a functional op to an out variant op
12- convert an out variant op to a scratch op.
13
14We assume there is already a functionalization pass being done that removes aliases and inplace variants.
15
16For the to_out_variant convertion, The functional variant will be represented
17as qualified op name plus the overload name. The returned out variant constains
18the following information
19- the OpOverload for the out variant
20- the list of keyward arguments names that are out variables. There should be
21  at least one out variables. Some ops may also have multiple out variables,
22  e.g. aten::topk.values returns both values and indices for the topk elements.
23
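For example (an illustrative sketch, assuming the standard aten library is
loaded, so that aten::add.Tensor has the out variant aten::add.out whose
single out argument is named "out"):

    out_op, out_args = to_out_variant(torch.ops.aten.add.Tensor)
    # out_op   -> torch.ops.aten.add.out
    # out_args -> ("out",)
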
24"""
25
26import dataclasses
27import logging
28from typing import Dict, Optional, Tuple
29
30import torch
31from torch._ops import OpOverload
32from torchgen.model import FunctionSchema, SchemaKind
33
# Cache the FunctionSchema so we don't need to parse it every time.
# Use OpOverload as the hash key. We cannot use torch._C.FunctionSchema as the
# key since it's not hashable.
_op_overload_to_schema_cache: Dict[OpOverload, FunctionSchema] = {}

# The value type is Optional so we can cache None if an op does not have an
# out variant/scratch op. This way, we don't confuse the op-not-existing case
# with a cache miss.
_func_to_out_variant_map: Dict[OpOverload, Optional[OpOverload]] = {}
_out_variant_to_scratch_map: Dict[OpOverload, Optional[OpOverload]] = {}
_mutable_to_out_variant_map: Dict[OpOverload, Optional[OpOverload]] = {}

# We've found a functional and an out variant with the same name, but their
# schemas mismatch. This map collects all of these cases and provides a proper
# error message to the user. The key is an `OpOverload` of a functional
# variant.
_schema_mismatch_map: Dict[OpOverload, Optional[FunctionSchema]] = {}


def _pybind_schema_to_native_schema(
    pybind_schema: torch._C.FunctionSchema,
) -> Optional[FunctionSchema]:
    """
    We have two FunctionSchema definitions in Python.
    One is defined in torchgen (call it the native FunctionSchema); the other
    is a pybind of c10::FunctionSchema (call it the pybind FunctionSchema).
    Because we want to leverage torchgen to handle out variants, we convert
    any pybind FunctionSchema to a native FunctionSchema.
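
    For example (an illustrative sketch):

        native = _pybind_schema_to_native_schema(
            torch.ops.aten.relu.default._schema
        )
        # native is a torchgen FunctionSchema, or None if parsing failed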
61    """
62    native_schema = None
63    try:
64        native_schema = FunctionSchema.parse(str(pybind_schema))
    except (RuntimeError, AssertionError, ValueError):
        # We need to catch AssertionError since parsing prim ops like:
        #   aten::to.prim_other(Tensor(a) self, bool non_blocking=False, bool copy=False) -> Tensor(a|b)
        # causes an assertion error in torchgen when parsing the annotation
        # 'a|b'. We should ignore it. Hopefully one day the C++ FunctionSchema
        # parsing is 100% consistent with the Python FunctionSchema parsing;
        # then we won't need to catch these exceptions anymore.

        # We also need to catch ValueError for schemas like:
        #   aten::copy.Dict_str(Dict(str, t)(a) self) -> Dict(str, t)
        # torchgen throws ValueError since it does not expect the type string
        # to contain commas. Ignore those schemas for now.
        logging.debug(f"Failed to parse function schema: {str(pybind_schema)}")
        # Ignore the failure and return None. There are some schemas defined
        # as prim ops that cannot be parsed by torchgen. E.g.:
        #   https://www.fburl.com/code/1vvzhssa
        # We should be safe to ignore them since PyE is not using these ops.
    return native_schema


def _get_overload_schema(op_overload: OpOverload) -> Optional[FunctionSchema]:
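    """
    Return the torchgen FunctionSchema for the given OpOverload, caching the
    parsed result. Return None if torchgen cannot parse the schema.
    """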
    native_schema = _op_overload_to_schema_cache.get(op_overload)
    if not native_schema:
        native_schema = _pybind_schema_to_native_schema(op_overload._schema)
        _op_overload_to_schema_cache[op_overload] = native_schema  # pyre-ignore
    return native_schema


def get_out_args_from_opoverload(op_overload: OpOverload) -> Tuple[str, ...]:
    return get_out_args_from_schema(_get_overload_schema(op_overload))  # pyre-ignore


def get_out_args_from_schema(out_var_schema: FunctionSchema) -> Tuple[str, ...]:
    """
    Assume the input is the schema for an out variant.
    Return the names of the out arguments.
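
    For example (illustrative; the schema is elided with "..."), the out
    variant
        aten::topk.values(Tensor self, ..., *, Tensor(a!) values, Tensor(b!) indices) -> (...)
    yields ("values", "indices").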
101    """
102    assert (
103        out_var_schema.is_out_fn()
104    ), f"Expect an out variant, but get: {out_var_schema}"
    return tuple(arg.name for arg in out_var_schema.arguments.out)


def parse_qualified_opname(qualified_opname: str) -> Tuple[str, str]:
    """
    Given a qualified opname like aten::add, return a tuple of the namespace
    (aten here) and the op name (add here).
    """
    ns_and_opname = qualified_opname.split("::")
    if len(ns_and_opname) != 2:
        raise RuntimeError(f"Invalid qualified_opname {qualified_opname}")
    return tuple(ns_and_opname)


def get_op_overload(qualified_opname: str, overload: str) -> OpOverload:
    """
    Arguments:
        qualified_opname: a string like {namespace}::{op name}
        overload: the overload string of the op; an empty string means the
            default overload
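
    For example (an illustrative sketch, assuming the standard aten ops are
    registered):

        get_op_overload("aten::add", "Tensor")  # torch.ops.aten.add.Tensor
        get_op_overload("aten::relu", "")       # torch.ops.aten.relu.default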
124    """
125    ns, opname = parse_qualified_opname(qualified_opname)
126    if not overload:
127        overload = "default"
128    return getattr(getattr(getattr(torch.ops, ns), opname), overload)
129
130
131def schema_to_opoverload(schema: FunctionSchema) -> OpOverload:
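    """
    Resolve a torchgen FunctionSchema back to the OpOverload it describes.
    """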
    qualified_name = str(schema.name.name)
    overload = schema.name.overload_name
    return get_op_overload(qualified_name, overload)


def set_mapping_for_op(op: OpOverload) -> None:
    """
    op can be a functional op, a mutable op, or an out variant op.
    This method is only called if
    1. op is a functional (or mutable) op and it's missing from the
       _func_to_out_variant_map cache, or
    2. op is an out variant op and it's missing from the
       _out_variant_to_scratch_map cache.

    Set up entries in _func_to_out_variant_map and _out_variant_to_scratch_map
    for all ops sharing the same op name as the passed-in OpOverload.
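
    For example (an illustrative sketch), calling this with
    torch.ops.aten.add.Tensor groups the schemas registered under "aten::add"
    by signature and records, among other entries:

        _func_to_out_variant_map[torch.ops.aten.add.Tensor] = torch.ops.aten.add.out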
146    """
147    native_schema = _pybind_schema_to_native_schema(op._schema)
148    # pyre-fixme[16]: `Optional` has no attribute `kind`.
149    assert native_schema.kind() in (
150        SchemaKind.functional,
151        SchemaKind.out,
152        SchemaKind.mutable,
153    )
154    assert not (
155        native_schema.kind() == SchemaKind.functional and op in _func_to_out_variant_map
156    )
157    assert not (
158        native_schema.kind() == SchemaKind.out and op in _out_variant_to_scratch_map
159    )
160    assert not (
161        native_schema.kind() == SchemaKind.mutable and op in _mutable_to_out_variant_map
162    )
163    qualified_opname = str(op._schema.name)
164
165    all_schemas = [
166        _pybind_schema_to_native_schema(pybind_schema)
167        for pybind_schema in torch._C._jit_get_schemas_for_operator(qualified_opname)
168    ]
169
    # Skip the schemas that cannot be parsed by torchgen.
    all_schemas = [schema for schema in all_schemas if schema is not None]

    group_by_signature: Dict[FunctionSchema, Dict[SchemaKind, FunctionSchema]] = {}

    for schema in all_schemas:
        signature = schema.signature()
        # Override the return type with an empty tuple. Otherwise, for ops
        # like aten.split.Tensor_out that return a Tensor list, the signature
        # of the schema does not match the one for the functional op
        # aten.split.Tensor because of the different return type.
        # Schema for aten.split.Tensor_out:
        #   split.Tensor_out(Tensor(a -> *) self, int split_size, int dim=0, *, Tensor(a!)[] out) -> ()
        # Schema for aten.split.Tensor:
        #   split.Tensor(Tensor(a -> *) self, int split_size, int dim=0) -> Tensor(a)[]
        # The reason for the above inconsistency is explained in: https://github.com/pytorch/pytorch/pull/76049
        signature = dataclasses.replace(signature, returns=())

        kind = schema.kind()
        group_by_kind = group_by_signature.setdefault(signature, {})
        assert (
            kind not in group_by_kind
        ), f"A schema of kind {kind} already exists for {schema}"
        group_by_kind[kind] = schema

    # add all the functional op -> out variant op pairs to the cache
    for group_by_kind in group_by_signature.values():
        func_op_schema = group_by_kind.get(SchemaKind.functional)
        out_var_schema = group_by_kind.get(SchemaKind.out)
        mutable_op_schema = group_by_kind.get(SchemaKind.mutable)
        scratch_schema = group_by_kind.get(SchemaKind.scratch)

        # update the map even if out_var_schema is None to cache the negative
        # case
        if func_op_schema:
            _func_to_out_variant_map[schema_to_opoverload(func_op_schema)] = (
                schema_to_opoverload(out_var_schema) if out_var_schema else None
            )
            # out variant schema missing from group_by_kind
            if out_var_schema is None:
                # find an out variant whose schema differs from the functional variant
                mismatched_out_schema: Optional[FunctionSchema] = next(
                    (s for s in all_schemas if s.kind() == SchemaKind.out), None
                )
                _schema_mismatch_map[schema_to_opoverload(func_op_schema)] = (
                    mismatched_out_schema
                )

        # update the map even if scratch_schema is None to cache the negative
        # case
        if out_var_schema:
            _out_variant_to_scratch_map[schema_to_opoverload(out_var_schema)] = (
                schema_to_opoverload(scratch_schema) if scratch_schema else None
            )
        if mutable_op_schema:
            _mutable_to_out_variant_map[schema_to_opoverload(mutable_op_schema)] = (
                schema_to_opoverload(out_var_schema) if out_var_schema else None
            )


def to_out_variant(op_overload: OpOverload) -> Tuple[OpOverload, Tuple[str, ...]]:
    r"""
    Convert the passed-in OpOverload to its out variant. Raise an exception
    if no out variant can be found.

    If a conversion is found, return the out variant OpOverload along with
    the names of the out arguments.
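
    For example (an illustrative sketch, assuming the standard aten library
    is loaded):

        to_out_variant(torch.ops.aten.topk.default)
        # -> (torch.ops.aten.topk.values, ("values", "indices"))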
238    """
239    schema = _get_overload_schema(op_overload)
240    if schema.is_out_fn():  # pyre-ignore
241        return op_overload, get_out_args_from_schema(schema)  # pyre-ignore[6]
242
243    # should be a functionalish op here
244    assert (
245        schema.kind() == SchemaKind.functional  # pyre-ignore[16]
246        or schema.kind() == SchemaKind.mutable
247    ), f"Expect a functionalish op, but get {schema.kind()} {schema}"

    if (
        op_overload not in _func_to_out_variant_map
        and op_overload not in _mutable_to_out_variant_map
    ):
        # set up the out variant mapping
        set_mapping_for_op(op_overload)

    if op_overload in _mutable_to_out_variant_map:
        out_var = _mutable_to_out_variant_map[op_overload]
    else:
        out_var = _func_to_out_variant_map.get(op_overload)

    if not out_var:
        msg = f"Missing out variant for functional op: {schema}. Make sure you have loaded your custom operator library for the compiler, e.g., custom_ops_generated_lib."
        if op_overload in _schema_mismatch_map:
            if _schema_mismatch_map[op_overload]:
                msg += (
                    f"\nFound an out variant for operator name {op_overload.name()}, but its schema does not match the functional op."
                    f"\nfunctional op schema:\t{schema}"
                    f"\nout variant op schema:\t{_schema_mismatch_map[op_overload]}"
                )
        raise RuntimeError(msg)

    return out_var, get_out_args_from_opoverload(out_var)


def to_scratch_op(op_overload: OpOverload) -> Optional[OpOverload]:
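    """
    Convert the passed-in out variant OpOverload to its scratch op, if one is
    registered. Return None otherwise, including when the input is not an out
    variant.
    """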
    schema = _get_overload_schema(op_overload)

    # If the op is not an out variant, then we must have ignored some failure
    # in the to_out_variant pass. Return immediately rather than raising an
    # exception, since the user must have ignored errors for some reason
    # (e.g. designing some special unit tests, or unblocking new use cases).
    if schema.kind() != SchemaKind.out:  # pyre-ignore
        logging.debug(f"Expected an out variant op as input, got: {schema.kind()}")
        return None

    if op_overload not in _out_variant_to_scratch_map:
        set_mapping_for_op(op_overload)
    scratch_op = _out_variant_to_scratch_map.get(op_overload)

    # scratch_op can be None
    return scratch_op


def is_out_variant(qualified_opname: str, overload: str) -> bool:
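    """
    Return True if the given op is an out variant. For example (an
    illustrative check), is_out_variant("aten::add", "out") should be True.
    """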
    op_overload = get_op_overload(qualified_opname, overload)
    schema = _get_overload_schema(op_overload)
    if schema is None:
        return False
    return schema.is_out_fn()


def is_inplace_variant(qualified_opname: str, overload: str) -> bool:
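    """
    Return True if the given op is an inplace variant. For example (an
    illustrative check), is_inplace_variant("aten::add_", "Tensor") should be
    True.
    """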
    op_overload = get_op_overload(qualified_opname, overload)
    schema = _get_overload_schema(op_overload)
    if schema is None:
        return False
    return schema.kind() == SchemaKind.inplace
308