# xref: /aosp_15_r20/external/pytorch/torchgen/_autoheuristic/ah_tree.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from sklearn.tree import _tree  # type: ignore[import-untyped]


class DecisionTreeNode:
    def __init__(
        self,
        feature: Optional[str] = None,
        threshold: Optional[float] = None,
        left: Optional["DecisionTreeNode"] = None,
        right: Optional["DecisionTreeNode"] = None,
        class_probs: Any = None,
        num_samples: int = 0,
        node_id: int = 0,
    ) -> None:
        self.feature = feature
        self.threshold = threshold
        self.left = left
        self.right = right
        self.class_probs = class_probs
        self.num_samples = num_samples
        self.id = node_id

    def is_leaf(self) -> bool:
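        # a node converted from sklearn either has both children or none,
        # so checking either child is enough to detect a leaf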
        return self.left is None or self.right is None


class DecisionTree:
    """
    Custom decision tree implementation that mimics some of the sklearn API.
    The purpose of this class is to make it possible to perform transformations,
    such as custom pruning, which do not seem to be easy to do with sklearn.
    """
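
    # A minimal usage sketch (hypothetical names; assumes a fitted
    # sklearn.tree.DecisionTreeClassifier `clf` and a pandas DataFrame `df`
    # holding the feature columns `feature_cols` plus a target column "choice"):
    #
    #     tree = DecisionTree(clf, feature_cols)
    #     tree.prune(df, target_col="choice", k=2)
    #     predictions = tree.predict(df[feature_cols])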

    def __init__(self, sklearn_tree: Any, feature_names: List[str]) -> None:
        self.feature_names = feature_names
        self.root = self._convert_sklearn_tree(sklearn_tree.tree_)
        self.classes_: List[str] = sklearn_tree.classes_

    def _convert_sklearn_tree(
        self, sklearn_tree: Any, node_id: int = 0
    ) -> DecisionTreeNode:
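        """
        Recursively converts sklearn's internal tree representation (the fitted
        classifier's `tree_` attribute) into a tree of DecisionTreeNode objects.
        """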
        class_probs = sklearn_tree.value[node_id][0]
        num_samples = sklearn_tree.n_node_samples[node_id]
        if sklearn_tree.feature[node_id] != _tree.TREE_UNDEFINED:
            feature_index = sklearn_tree.feature[node_id]
            feature = self.feature_names[feature_index]
            left = self._convert_sklearn_tree(
                sklearn_tree, sklearn_tree.children_left[node_id]
            )
            right = self._convert_sklearn_tree(
                sklearn_tree, sklearn_tree.children_right[node_id]
            )
            return DecisionTreeNode(
                feature=feature,
                threshold=sklearn_tree.threshold[node_id],
                left=left,
                right=right,
                class_probs=class_probs,
                num_samples=num_samples,
                node_id=node_id,
            )
        else:
            return DecisionTreeNode(
                class_probs=class_probs, num_samples=num_samples, node_id=node_id
            )

    def prune(self, df: Any, target_col: str, k: int) -> None:
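        """
        Prunes the tree for ranking: any split whose left or right subtree
        covers fewer than k distinct values of df[target_col] is collapsed
        into a leaf, so that every leaf can still rank at least k classes.
        """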
        self.root = self._prune_tree(self.root, df, target_col, k)

    def _prune_tree(
        self, node: DecisionTreeNode, df: Any, target_col: str, k: int
    ) -> DecisionTreeNode:
        if node.is_leaf():
            return node

        left_df = df[df[node.feature] <= node.threshold]
        right_df = df[df[node.feature] > node.threshold]

        # number of unique classes in the left and right subtrees
        left_counts = left_df[target_col].nunique()
        right_counts = right_df[target_col].nunique()

        # for ranking, we want to ensure that we return at least k classes, so if either subtree
        # contains fewer than k classes, we remove the split and turn this node into a leaf,
        # keeping its class probabilities, sample count, and id
        if left_counts < k or right_counts < k:
            return DecisionTreeNode(
                class_probs=node.class_probs, num_samples=node.num_samples, node_id=node.id
            )

        assert node.left is not None, "expected left child to exist"
        node.left = self._prune_tree(node.left, left_df, target_col, k)
        assert node.right is not None, "expected right child to exist"
        node.right = self._prune_tree(node.right, right_df, target_col, k)

        return node

    def to_dot(self) -> str:
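        """
        Returns a Graphviz DOT representation of the tree, which can be
        rendered with e.g. `dot -Tpng`.
        """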
        dot = "digraph DecisionTree {\n"
        dot += '    node [fontname="helvetica"];\n'
        dot += '    edge [fontname="helvetica"];\n'
        dot += self._node_to_dot(self.root)
        dot += "}"
        return dot

    def _node_to_dot(
        self, node: DecisionTreeNode, parent_id: int = 0, edge_label: str = ""
    ) -> str:
        if node is None:
            return ""

        node_id = id(node)

        # Format class_probs array with line breaks
        class_probs_str = self._format_class_probs_array(
            node.class_probs, node.num_samples
        )

        if node.is_leaf():
            label = class_probs_str
            shape = "box"
        else:
            label = f"{node.feature} <= {node.threshold:.2f}\\n{class_probs_str}"
            shape = "oval"

        dot = f'    {node_id} [label="{label}", shape={shape}];\n'

        if parent_id != 0:
            dot += f'    {parent_id} -> {node_id} [label="{edge_label}"];\n'

        if not node.is_leaf():
            assert node.left is not None, "expected left child to exist"
            dot += self._node_to_dot(node.left, node_id, "<=")
            assert node.right is not None, "expected right child to exist"
            dot += self._node_to_dot(node.right, node_id, ">")

        return dot

    def _format_class_prob(self, num: float) -> str:
        if num == 0:
            return "0"
        return f"{num:.2f}"

    def _format_class_probs_array(
        self, class_probs: Any, num_samples: int, max_per_line: int = 5
    ) -> str:
        # add line breaks to avoid very long lines
        flat_class_probs = class_probs.flatten()
        formatted = [self._format_class_prob(v) for v in flat_class_probs]
        lines = [
            formatted[i : i + max_per_line]
            for i in range(0, len(formatted), max_per_line)
        ]
        return f"num_samples={num_samples}\\n" + "\\n".join(
            [", ".join(line) for line in lines]
        )

    def predict(self, X: Any) -> Any:
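        """
        Returns the predicted class name for each row of the DataFrame X.
        """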
        predictions = [self._predict_single(x) for _, x in X.iterrows()]
        return np.array(predictions)

    def predict_proba(self, X: Any) -> Any:
        return np.array([self._predict_proba_single(x) for _, x in X.iterrows()])

    def _get_leaf(self, x: Any) -> DecisionTreeNode:
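        """
        Walks from the root to a leaf, going left if the row's feature value is
        at most the threshold and right otherwise, and returns that leaf.
        """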
        node = self.root
        while not node.is_leaf():
            if x[node.feature] <= node.threshold:
                assert node.left is not None, "expected left child to exist"
                node = node.left
            else:
                assert node.right is not None, "expected right child to exist"
                node = node.right
        return node

    def _predict_single(self, x: Any) -> str:
        node = self._get_leaf(x)
        # map index to class name
        return self.classes_[np.argmax(node.class_probs)]

    def _predict_proba_single(self, x: Any) -> Any:
        node = self._get_leaf(x)
        return node.class_probs

    def apply(self, X: Any) -> Any:
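        """
        Returns the id of the leaf that each row of X ends up in, similar to
        sklearn's `apply`.
        """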
        ids = [self._apply_single(x) for _, x in X.iterrows()]
        return np.array(ids)

    def _apply_single(self, x: Any) -> int:
        node = self._get_leaf(x)
        return node.id

    def codegen(
        self,
        dummy_col_2_col_val: Dict[str, Tuple[str, Any]],
        lines: List[str],
        unsafe_leaves: List[int],
    ) -> None:
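        """
        Generates Python code for the decision tree and appends it to `lines`.
        Inner nodes become `if`/`else` statements on feature values read from a
        `context` object; leaves become `return` statements yielding either None
        (for unsafe leaves) or a list of (proba, choice_index) tuples sorted by
        probability in descending order. The generated code has roughly this
        shape (feature name and numbers are illustrative):

            if context.get_value('num_elements') <= 1024.5:
                return None
            else:
                return [(0.800, 2), (0.200, 0)]
        """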
        def codegen_node(node: DecisionTreeNode, depth: int) -> None:
            indent = "    " * (depth + 1)
            if node.is_leaf():
                lines.append(handle_leaf(node, indent, unsafe_leaves))
            else:
                name = node.feature
                threshold = node.threshold
                if name in dummy_col_2_col_val:
                    (orig_name, value) = dummy_col_2_col_val[name]
                    predicate = f"{indent}if str(context.get_value('{orig_name}')) != '{value}':"
                    assert (
                        threshold == 0.5
                    ), f"expected threshold to be 0.5 but is {threshold}"
                else:
                    predicate = (
                        f"{indent}if context.get_value('{name}') <= {threshold}:"
                    )
                lines.append(predicate)
                assert node.left is not None, "expected left child to exist"
                codegen_node(node.left, depth + 1)
                lines.append(f"{indent}else:")
                assert node.right is not None, "expected right child to exist"
                codegen_node(node.right, depth + 1)

        def handle_leaf(
            node: DecisionTreeNode, indent: str, unsafe_leaves: List[int]
        ) -> str:
            """
            This generates the code for a leaf node in the decision tree. If the leaf is unsafe,
            the learned heuristic will return "unsure" (i.e. None).
            """
            if node.id in unsafe_leaves:
                return f"{indent}return None"
            class_probas = node.class_probs
            return f"{indent}return {best_probas_and_indices(class_probas)}"

        def best_probas_and_indices(class_probas: Any) -> str:
            """
            Given a list of probabilities, this function returns a string representation of a
            list of (proba, idx) tuples, where idx is the index of a choice, sorted by proba
            in descending order. E.g.:
            Given class_probas=[0.3, 0.5, 0.2]
            this function returns
            "[(0.500, 1), (0.300, 0), (0.200, 2)]"
            """
            # we generate a list of tuples (proba, idx) sorted by proba in descending order
            # idx is the index of a choice
            # we only generate a tuple if proba > 0
            probas_indices_sorted = sorted(
                [
                    (proba, index)
                    for index, proba in enumerate(class_probas)
                    if proba > 0
                ],
                key=lambda x: x[0],
                reverse=True,
            )
            probas_indices_sorted_str = ", ".join(
                f"({value:.3f}, {index})" for value, index in probas_indices_sorted
            )
            return f"[{probas_indices_sorted_str}]"

        codegen_node(self.root, 1)