from __future__ import annotations

from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
from typing import TYPE_CHECKING

import yaml

from torchgen.selective_build.operator import (
    merge_debug_info,
    merge_operator_dicts,
    SelectiveBuildOperator,
    strip_operator_overload_name,
)


if TYPE_CHECKING:
    from torchgen.model import NativeFunction


# A SelectiveBuilder holds information extracted from the selective build
# YAML specification.
#
# It includes information about the build's selectivity, the debug_info
# associated with this selective build (opaque string), and the set of
# operators that should be included in the build.
#
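# A sketch of the YAML shape consumed by from_yaml_dict() below; the operator
# name, kernel tag, and dtypes here are hypothetical examples, not entries
# from any real model:
#
#   include_all_operators: False
#   include_all_non_op_selectives: False
#   operators:
#     aten::add.Tensor:
#       name: aten::add.Tensor
#       is_root_operator: True
#       is_used_for_training: False
#       include_all_overloads: False
#   kernel_metadata:
#     some_kernel_tag: [float, int]
#   custom_classes: []
#   build_features: []
#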
@dataclass(frozen=True)
class SelectiveBuilder:
    # If true, then the build is not selective, and includes all
    # operators.
    include_all_operators: bool

    # Debug Information at the selective/custom build level.
    _debug_info: tuple[str, ...] | None

    # A dictionary of operator -> operator metadata.
    operators: dict[str, SelectiveBuildOperator]

    # A dictionary of selected kernel tags and dtypes. A PyTorch operator
    # kernel (function) typically has many code paths specialized for
    # different Tensor dtypes, so there may be several tags per kernel
    # function rather than one. A tag isn't the kernel function's name,
    # but some fragment of the kernel function's implementation.
    kernel_metadata: dict[str, list[str]]
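    # (For illustration only: an entry might look like
    # {"some_kernel_tag": ["float", "int"]}, with a hypothetical tag name.)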

    # ExecuTorch only. A dictionary of kernel tag -> list of (list of input
    # dtypes for tensor-like input args). This comes from selective.yaml.
    et_kernel_metadata: dict[str, list[str]]
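    # (For illustration only, names are hypothetical: an entry might map
    # "aten::add.out" to kernel keys like ["v1/<dtype-spec>"], where the part
    # after the "/" encodes the input dtypes; et_get_selected_kernels() below
    # matches on that part, with "default" as the catch-all key.)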

    # A set of all the custom TorchBind classes used by the selected models.
    # Stored as a set internally to remove duplicates proactively, but written
    # out as a list in YAML.
    custom_classes: set[str]

    # A set of all the build features used by the selected models.
    # Stored as a set internally to remove duplicates proactively, but written
    # out as a list in YAML.
    build_features: set[str]

    # If true, then fragments for all dtypes of all kernel functions are
    # included, as are all custom classes. This is typically set when any one
    # of the operator lists is generated by a mechanism other than
    # tracing-based selective build.
    include_all_non_op_selectives: bool

    @staticmethod
    def get_nop_selector() -> SelectiveBuilder:
        return SelectiveBuilder.from_yaml_dict({"include_all_operators": True})

    @staticmethod
    def from_yaml_dict(data: dict[str, object]) -> SelectiveBuilder:
        valid_top_level_keys = {
            "include_all_non_op_selectives",
            "include_all_operators",
            "debug_info",
            "operators",
            "kernel_metadata",
            "et_kernel_metadata",
            "custom_classes",
            "build_features",
        }
        top_level_keys = set(data.keys())
        if len(top_level_keys - valid_top_level_keys) > 0:
            raise Exception(  # noqa: TRY002
                "Got unexpected top level keys: {}".format(
                    ",".join(top_level_keys - valid_top_level_keys),
                )
            )
        include_all_operators = data.get("include_all_operators", False)
        assert isinstance(include_all_operators, bool)

        debug_info = None
        if "debug_info" in data:
            di_list = data["debug_info"]
            assert isinstance(di_list, list)

            debug_info = tuple(str(x) for x in di_list)

        operators = {}
        operators_dict = data.get("operators", {})
        assert isinstance(operators_dict, dict)

        for k, v in operators_dict.items():
            operators[k] = SelectiveBuildOperator.from_yaml_dict(k, v)

        kernel_metadata = {}
        kernel_metadata_dict = data.get("kernel_metadata", {})
        assert isinstance(kernel_metadata_dict, dict)

        for k, v in kernel_metadata_dict.items():
            kernel_metadata[str(k)] = [str(dtype) for dtype in v]

        et_kernel_metadata = data.get("et_kernel_metadata", {})
        assert isinstance(et_kernel_metadata, dict)

        custom_classes = data.get("custom_classes", [])
        assert isinstance(custom_classes, Iterable)
        custom_classes = set(custom_classes)

        build_features = data.get("build_features", [])
        assert isinstance(build_features, Iterable)
        build_features = set(build_features)

        include_all_non_op_selectives = data.get("include_all_non_op_selectives", False)
        assert isinstance(include_all_non_op_selectives, bool)

        return SelectiveBuilder(
            include_all_operators,
            debug_info,
            operators,
            kernel_metadata,
            et_kernel_metadata,
            custom_classes,  # type: ignore[arg-type]
            build_features,  # type: ignore[arg-type]
            include_all_non_op_selectives,
        )

    @staticmethod
    def from_yaml_str(config_contents: str) -> SelectiveBuilder:
        contents = yaml.safe_load(config_contents)
        return SelectiveBuilder.from_yaml_dict(contents)
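
    # Illustrative use of from_yaml_str (the operator name and flags are
    # hypothetical, not taken from a real config):
    #
    #   selector = SelectiveBuilder.from_yaml_str(
    #       "operators:\n"
    #       "  aten::add.Tensor:\n"
    #       "    name: aten::add.Tensor\n"
    #       "    is_root_operator: True\n"
    #       "    is_used_for_training: False\n"
    #       "    include_all_overloads: False\n"
    #   )
    #   assert selector.is_operator_selected("aten::add.Tensor")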

    @staticmethod
    def from_yaml_path(config_path: str) -> SelectiveBuilder:
        with open(config_path) as f:
            contents = yaml.safe_load(f)
            return SelectiveBuilder.from_yaml_dict(contents)

    @staticmethod
    def from_legacy_op_registration_allow_list(
        allow_list: set[str], is_root_operator: bool, is_used_for_training: bool
    ) -> SelectiveBuilder:
        operators = {}
        for op in allow_list:
            operators[op] = {
                "name": op,
                "is_root_operator": is_root_operator,
                "is_used_for_training": is_used_for_training,
                "include_all_overloads": True,
            }
        return SelectiveBuilder.from_yaml_dict(
            {
                "operators": operators,
                "include_all_non_op_selectives": True,
            }
        )
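
    # Illustrative sketch: from_legacy_op_registration_allow_list({"aten::add"},
    # is_root_operator=True, is_used_for_training=False) yields a builder in
    # which every overload of the (hypothetical) allow-listed op is selected,
    # since include_all_overloads is set to True for each entry.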

    def is_operator_selected(self, name: str) -> bool:
        if self.include_all_operators:
            return True

        if name in self.operators:
            return True
        name = strip_operator_overload_name(name)
        return name in self.operators and self.operators[name].include_all_overloads

    def is_native_function_selected(self, func: NativeFunction) -> bool:
        op_name = op_name_from_native_function(func)
        return self.is_operator_selected(op_name)

    def is_operator_selected_for_training(self, name: str) -> bool:
        if not self.is_operator_selected(name):
            return False
        if self.include_all_operators:
            return True

        not_training_op = SelectiveBuildOperator(
            name="",
            is_root_operator=False,
            is_used_for_training=False,
            include_all_overloads=False,
            _debug_info=None,
        )
        op = not_training_op
        if name in self.operators:
            op = self.operators[name]

        name = strip_operator_overload_name(name)
        base_op = not_training_op
        if name in self.operators:
            base_op = self.operators[name]

        return op.is_used_for_training or (
            base_op.include_all_overloads and base_op.is_used_for_training
        )
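
    # Example of the overload fallback above (operator names are hypothetical):
    # if "aten::add" is selected with include_all_overloads=True and
    # is_used_for_training=True, then a specific overload such as
    # "aten::add.Tensor" is also treated as selected for training, even if it
    # has no entry of its own.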

    def is_native_function_selected_for_training(self, func: NativeFunction) -> bool:
        op_name = op_name_from_native_function(func)
        return self.is_operator_selected_for_training(op_name)

    def is_root_operator(self, name: str) -> bool:
        if not self.is_operator_selected(name):
            return False
        if self.include_all_operators:
            return True

        if name in self.operators:
            op: SelectiveBuildOperator = self.operators[name]
            return op.is_root_operator
        name = strip_operator_overload_name(name)
        if name not in self.operators:
            return False
        base_op: SelectiveBuildOperator = self.operators[name]
        return base_op.include_all_overloads and base_op.is_root_operator

    def is_kernel_dtype_selected(self, kernel_tag: str, dtype: str) -> bool:
        if self.include_all_operators or self.include_all_non_op_selectives:
            return True

        return (
            kernel_tag in self.kernel_metadata
            and dtype in self.kernel_metadata[kernel_tag]
        )

    def et_get_selected_kernels(self, op_name: str, kernel_key: list[str]) -> list[str]:
        """
        Return a list of kernel keys that cover the used ops
        """
        # If no kernel metadata, either it's implied by include_all_operators=True or the op is not used.
        if op_name not in self.et_kernel_metadata:
            return kernel_key if self.include_all_operators else []
        # Otherwise, only return the specific kernel keys.

        result_set = set()

        for model_kernel_keys in self.et_kernel_metadata[op_name]:
            key_found = False
            for key in kernel_key:
                # Don't compare the version for now
                if (
                    key != "default"
                    and key.split("/")[1] == model_kernel_keys.split("/")[1]
                ):
                    result_set.add(key)
                    key_found = True
                    break
            if not key_found:
                if "default" not in kernel_key:
                    raise Exception("Missing kernel for the model")  # noqa: TRY002
                else:
                    result_set.add("default")

        return list(result_set)
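
    # Illustrative sketch of the matching above (kernel keys are hypothetical):
    # with kernel_key=["v1/6;0,1", "default"] and a model entry of "v1/3;0,2"
    # for op_name, the dtype portions after the "/" differ, so "default" is
    # returned; if "default" were absent from kernel_key, an exception would
    # be raised instead.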

    def to_dict(self) -> dict[str, object]:
        ret: dict[str, object] = {
            "include_all_non_op_selectives": self.include_all_non_op_selectives,
            "include_all_operators": self.include_all_operators,
        }
        operators = {}
        for op_name, op in self.operators.items():
            operators[op_name] = op.to_dict()
        ret["operators"] = operators

        if self._debug_info is not None:
            ret["debug_info"] = sorted(self._debug_info)

        ret["kernel_metadata"] = {
            k: sorted(v) for (k, v) in self.kernel_metadata.items()
        }

        ret["et_kernel_metadata"] = self.et_kernel_metadata

        ret["custom_classes"] = sorted(self.custom_classes)

        ret["build_features"] = sorted(self.build_features)

        return ret


def merge_kernel_metadata(
    lhs: dict[str, list[str]],
    rhs: dict[str, list[str]],
) -> dict[str, list[str]]:
    kernel_metadata: dict[str, list[str]] = {}
    for tag_name, dtypes in list(lhs.items()) + list(rhs.items()):
        dtypes_copy = set(dtypes)
        if tag_name in kernel_metadata:
            dtypes_copy |= set(kernel_metadata[tag_name])

        kernel_metadata[tag_name] = list(dtypes_copy)

    return kernel_metadata
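
# Illustrative sketch (tag names and dtypes hypothetical):
# merge_kernel_metadata({"tag": ["float"]}, {"tag": ["int"], "other": ["bool"]})
# yields {"tag": [...], "other": ["bool"]}, where "tag" maps to both "float"
# and "int"; the per-tag order is unspecified because the dtypes pass through
# a set.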


def merge_et_kernel_metadata(
    lhs: dict[str, list[str]],
    rhs: dict[str, list[str]],
) -> dict[str, list[str]]:
    merged_et_kernel_metadata: dict[str, set[str]] = defaultdict(set)
    for op in list(lhs.keys()) + list(rhs.keys()):
        merged_et_kernel_metadata[op].update(lhs.get(op, []))
        merged_et_kernel_metadata[op].update(rhs.get(op, []))

    return {op: sorted(val) for op, val in merged_et_kernel_metadata.items()}
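
# Same idea as merge_kernel_metadata above, but keyed by op and with the merged
# kernel-key lists returned in sorted order, e.g. (hypothetical) merging
# {"aten::add.out": ["v1/b"]} with {"aten::add.out": ["v1/a"]} gives
# {"aten::add.out": ["v1/a", "v1/b"]}.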


def combine_selective_builders(
    lhs: SelectiveBuilder, rhs: SelectiveBuilder
) -> SelectiveBuilder:
    include_all_operators = lhs.include_all_operators or rhs.include_all_operators
    debug_info = merge_debug_info(lhs._debug_info, rhs._debug_info)
    operators = merge_operator_dicts(lhs.operators, rhs.operators)
    kernel_metadata = merge_kernel_metadata(lhs.kernel_metadata, rhs.kernel_metadata)
    et_kernel_metadata = merge_et_kernel_metadata(
        lhs.et_kernel_metadata, rhs.et_kernel_metadata
    )
    include_all_non_op_selectives = (
        lhs.include_all_non_op_selectives or rhs.include_all_non_op_selectives
    )
    custom_classes = lhs.custom_classes.union(rhs.custom_classes)
    build_features = lhs.build_features.union(rhs.build_features)
    return SelectiveBuilder(
        include_all_operators,
        debug_info,
        operators,
        kernel_metadata,
        et_kernel_metadata,
        custom_classes,
        build_features,
        include_all_non_op_selectives,
    )


def op_name_from_native_function(f: NativeFunction) -> str:
    # This was originally read from the 'operator_name_with_overload' field in the
    # declaration dict, which was the part before the first '(' in 'schema_string'.
    return f"{f.namespace}::{f.func.name}"
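

if __name__ == "__main__":
    # Illustrative demo only (not part of the original module): build two
    # selectors from hypothetical operator allow-lists and combine them.
    lhs = SelectiveBuilder.from_legacy_op_registration_allow_list(
        {"aten::add"}, is_root_operator=True, is_used_for_training=False
    )
    rhs = SelectiveBuilder.from_legacy_op_registration_allow_list(
        {"aten::mul"}, is_root_operator=True, is_used_for_training=True
    )
    combined = combine_selective_builders(lhs, rhs)

    # "aten::add.Tensor" is covered because include_all_overloads is True for
    # every allow-listed op.
    print(combined.is_operator_selected("aten::add.Tensor"))
    print(combined.is_root_operator("aten::mul"))
    print(combined.to_dict())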