from __future__ import annotations

from dataclasses import dataclass


# This class holds information about a single operator used to determine
# the outcome of a selective/custom PyTorch build that doesn't include
# registration code for all the supported operators. This is done to
# reduce the size of the generated binary so that it can be deployed in
# situations where binary size comes at a premium.
#
@dataclass(frozen=True)
class SelectiveBuildOperator:
    # The name of the operator, including its namespace prefix (e.g. aten::).
    # The operator name may or may not include an overload name. If it does
    # not, the 'include_all_overloads' flag on this class determines whether
    # this entry refers to the whole family of operators sharing this base
    # name or only to the operator with exactly this name.
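    # For example, 'aten::add.Tensor' names a single overload, while
    # 'aten::add' with include_all_overloads=True would cover every
    # overload whose base name is aten::add (names here are illustrative).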
    name: str

    # True if this is a root operator, i.e. one called directly by at least
    # one of the models (e.g. TorchScript models) that this instance of the
    # PyTorch library was built for. Note that it need not be a root
    # operator in every one of those models.
    is_root_operator: bool

    # Is this operator used for on-device training? If True, this information
    # is used to generate registration code for training-related operators in
    # VariableType_N.cpp. As with is_root_operator, this is True if the
    # operator is used for training in one or more of the models that this
    # instance of the PyTorch library was built for.
    is_used_for_training: bool

    # If True, this operator instance (object) refers to an operator without
    # an overload name and applies to all overloads that have this operator
    # name as their base name. This flag is applicable only to objects whose
    # operator names do not contain a DOT (period) character.
    #
    # Note: This flag is a temporary workaround to grandfather in the current
    # static selective (custom) build mechanism, which largely ignores overload
    # names when determining whether to select operators for registration
    # purposes.
    include_all_overloads: bool

    # Debug information at the operator level
    _debug_info: tuple[str, ...] | None

    @staticmethod
    def from_yaml_dict(
        op_name: str, op_info: dict[str, object]
    ) -> SelectiveBuildOperator:
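        """Build a SelectiveBuildOperator from one entry of a selective
        build YAML file.

        Unspecified boolean fields default to True, and unexpected keys
        raise. A minimal usage sketch (the operator name and dict contents
        are illustrative):

            op = SelectiveBuildOperator.from_yaml_dict(
                "aten::add.Tensor",
                {"is_root_operator": True, "debug_info": ["model_a"]},
            )
        """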
        allowed_keys = {
            "name",
            "is_root_operator",
            "is_used_for_training",
            "include_all_overloads",
            "debug_info",
        }

        unexpected_keys = set(op_info.keys()) - allowed_keys
        if unexpected_keys:
            raise Exception(  # noqa: TRY002
                "Got unexpected top level keys: {}".format(",".join(unexpected_keys))
            )

        if "name" in op_info:
            assert op_name == op_info["name"]

        is_root_operator = op_info.get("is_root_operator", True)
        assert isinstance(is_root_operator, bool)

        is_used_for_training = op_info.get("is_used_for_training", True)
        assert isinstance(is_used_for_training, bool)

        include_all_overloads = op_info.get("include_all_overloads", True)
        assert isinstance(include_all_overloads, bool)

        debug_info: tuple[str, ...] | None = None
        if "debug_info" in op_info:
            di_list = op_info["debug_info"]
            assert isinstance(di_list, list)
            debug_info = tuple(str(x) for x in di_list)

        return SelectiveBuildOperator(
            name=op_name,
            is_root_operator=is_root_operator,
            is_used_for_training=is_used_for_training,
            include_all_overloads=include_all_overloads,
            _debug_info=debug_info,
        )

    @staticmethod
    def from_legacy_operator_name_without_overload(
        name: str,
    ) -> SelectiveBuildOperator:
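        """Build an operator entry from a bare (legacy) operator name.

        Legacy operator lists carry no metadata, so the operator is
        conservatively assumed to be a root operator, to be used for
        training, and to cover all of its overloads.
        """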
        return SelectiveBuildOperator(
            name=name,
            is_root_operator=True,
            is_used_for_training=True,
            include_all_overloads=True,
            _debug_info=None,
        )

    def to_dict(self) -> dict[str, object]:
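        """Serialize this operator back to a YAML-style dict.

        The operator name is not included because it serves as the key
        under which this dict is stored (mirroring from_yaml_dict).
        """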
        ret: dict[str, object] = {
            "is_root_operator": self.is_root_operator,
            "is_used_for_training": self.is_used_for_training,
            "include_all_overloads": self.include_all_overloads,
        }
        if self._debug_info is not None:
            ret["debug_info"] = self._debug_info

        return ret


def merge_debug_info(
    lhs: tuple[str, ...] | None,
    rhs: tuple[str, ...] | None,
) -> tuple[str, ...] | None:
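    """Union two optional debug-info tuples, keeping each entry once.

    For example, merging ("a", "b") with ("b", "c") yields a tuple
    containing "a", "b", and "c" exactly once; the ordering of the result
    is unspecified because the entries pass through a set.
    """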
    # Ensure that when merging, each entry shows up just once.
    if lhs is None and rhs is None:
        return None

    return tuple(set((lhs or ()) + (rhs or ())))


def combine_operators(
    lhs: SelectiveBuildOperator, rhs: SelectiveBuildOperator
) -> SelectiveBuildOperator:
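    """Merge two entries for the same operator into one.

    Boolean flags are OR-ed together: a flag is True in the result if it
    is True in either input, since an operator needs a capability if any
    model that uses it needs that capability.
    """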
    if lhs.name != rhs.name:
        raise Exception(  # noqa: TRY002
            f"Expected both arguments to have the same name, but got '{lhs.name}' and '{rhs.name}' instead"
        )

    return SelectiveBuildOperator(
        name=lhs.name,
        # Consider this operator to be a root operator if it is a
        # root operator in any of the models used in this instance of
        # the PyTorch library.
        is_root_operator=lhs.is_root_operator or rhs.is_root_operator,
        # Consider this operator to be a training operator if it is
        # used for training in any of the models used in this instance
        # of the PyTorch library.
        is_used_for_training=lhs.is_used_for_training or rhs.is_used_for_training,
        include_all_overloads=lhs.include_all_overloads or rhs.include_all_overloads,
        _debug_info=merge_debug_info(lhs._debug_info, rhs._debug_info),
    )


def merge_operator_dicts(
    lhs: dict[str, SelectiveBuildOperator],
    rhs: dict[str, SelectiveBuildOperator],
) -> dict[str, SelectiveBuildOperator]:
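    """Merge two operator dictionaries keyed by operator name.

    Entries present in both inputs are combined via combine_operators, so
    the result reflects the union of the capabilities required by either
    side. A minimal usage sketch (names are illustrative):

        merged = merge_operator_dicts(
            {"aten::add": op_a},
            {"aten::add": op_b, "aten::mul": op_c},
        )
        # merged["aten::add"] combines op_a and op_b
    """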
    operators: dict[str, SelectiveBuildOperator] = {}
    for op_name, op in list(lhs.items()) + list(rhs.items()):
        new_op = op
        if op_name in operators:
            new_op = combine_operators(operators[op_name], op)

        operators[op_name] = new_op

    return operators


def strip_operator_overload_name(op_name: str) -> str:
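    """Drop the overload suffix from an operator name, if present.

    For example, 'aten::add.Tensor' becomes 'aten::add', while a name
    containing no '.' is returned unchanged.
    """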
    return op_name.split(".")[0]
172